Support multiple text encoder configurations on SD3.
comfyanonymous committed Jun 11, 2024
1 parent 1c34d33 commit 5889b7c
Showing 3 changed files with 89 additions and 33 deletions.
2 changes: 1 addition & 1 deletion comfy/sd.py
@@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
        vae = VAE(sd=vae_sd)

    if output_clip:
-        clip_target = model_config.clip_target()
+        clip_target = model_config.clip_target(state_dict=sd)
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
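The loader-side change above is small: load_checkpoint_guess_config() now hands the checkpoint's state dict to clip_target(), so a model config can choose its text-encoder setup based on which weights are actually in the file. Because every clip_target() override gains a state_dict={} default (see comfy/supported_models.py below), callers that pass nothing keep working. A minimal sketch of the calling convention, using a stand-in config object rather than ComfyUI internals:

# Stand-in object; only the clip_target() signature mirrors this commit.
class StubConfig:
    def clip_target(self, state_dict={}):
        # A config may ignore state_dict or inspect its keys to pick encoders.
        return "clip target chosen from {} keys".format(len(state_dict))

sd = {"text_encoders.clip_g.logit_scale": 0.0}  # pretend checkpoint contents
config = StubConfig()
print(config.clip_target(state_dict=sd))  # new style: the loader passes the weights
print(config.clip_target())               # old style still works via the default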
85 changes: 63 additions & 22 deletions comfy/sd3_clip.py
@@ -5,6 +5,7 @@
import torch
import os
import comfy.model_management
+import logging

class T5XXLModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
@@ -43,42 +44,82 @@ def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

class SD3ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
        super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
-        self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+        if clip_l:
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+        else:
+            self.clip_l = None
+
+        if clip_g:
+            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+        else:
+            self.clip_g = None
+
+        if t5:
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+        else:
+            self.t5xxl = None
+
+        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))

    def set_clip_options(self, options):
-        self.clip_l.set_clip_options(options)
-        self.clip_g.set_clip_options(options)
-        self.t5xxl.set_clip_options(options)
+        if self.clip_l is not None:
+            self.clip_l.set_clip_options(options)
+        if self.clip_g is not None:
+            self.clip_g.set_clip_options(options)
+        if self.t5xxl is not None:
+            self.t5xxl.set_clip_options(options)

    def reset_clip_options(self):
-        self.clip_g.reset_clip_options()
-        self.clip_l.reset_clip_options()
-        self.t5xxl.reset_clip_options()
+        if self.clip_l is not None:
+            self.clip_l.reset_clip_options()
+        if self.clip_g is not None:
+            self.clip_g.reset_clip_options()
+        if self.t5xxl is not None:
+            self.t5xxl.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
        lg_out = None
        pooled = None
        out = None

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
-            l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-            g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
-            lg_out = torch.cat([l_out, g_out], dim=-1)
-            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
-            out = lg_out
+            if self.clip_l is not None:
+                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+            else:
+                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
+
+            if self.clip_g is not None:
+                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
+                if lg_out is not None:
+                    lg_out = torch.cat([lg_out, g_out], dim=-1)
+                else:
+                    lg_out = torch.nn.functional.pad(g_out, (768, 0))
+            else:
+                g_out = None
+                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
+
+            if lg_out is not None:
+                lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+                out = lg_out
            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
-        else:
-            pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())

-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
-        if lg_out is not None:
-            out = torch.cat([lg_out, t5_out], dim=-2)
-        else:
-            out = t5_out
+        if self.t5xxl is not None:
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+            if lg_out is not None:
+                out = torch.cat([lg_out, t5_out], dim=-2)
+            else:
+                out = t5_out
+
+        if out is None:
+            out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
+
+        if pooled is None:
+            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        return out, pooled

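With the changes above, SD3ClipModel tolerates any subset of the three text encoders. A missing encoder contributes zeros of its pooled width (768 for clip_l, 1280 for clip_g), a clip_g-only result is left-padded into the channel slot clip_l would have filled, the CLIP channels are then right-padded to T5's 4096 width, and if no encoder ran at all the output falls back to a zeroed (1, 77, 4096) sequence. A small self-contained sketch of that tensor bookkeeping, with random tensors standing in for real encoder outputs:

import torch

# Shapes follow SD3's text encoders: clip_l=768, clip_g=1280, t5xxl=4096 channels.
l_out = torch.randn(1, 77, 768)    # clip_l hidden states
g_out = torch.randn(1, 77, 1280)   # clip_g hidden states
t5_out = torch.randn(1, 77, 4096)  # t5xxl hidden states

# Both CLIPs present: concatenate on the channel dim, then zero-pad up to 4096.
lg_out = torch.cat([l_out, g_out], dim=-1)                        # (1, 77, 2048)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

# Only clip_g present: left-pad by 768 so the channels land where clip_l's would be.
g_only = torch.nn.functional.pad(g_out, (768, 0))                 # (1, 77, 2048)
g_only = torch.nn.functional.pad(g_only, (0, 4096 - g_only.shape[-1]))

# The padded CLIP block and the T5 block are stacked along the token dimension.
out = torch.cat([lg_out, t5_out], dim=-2)
print(out.shape)  # torch.Size([1, 154, 4096])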
35 changes: 25 additions & 10 deletions comfy/supported_models.py
@@ -54,7 +54,7 @@ def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
@@ -97,7 +97,7 @@ def process_clip_state_dict_for_saving(self, state_dict):
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

class SD21UnclipL(SD20):
@@ -159,7 +159,7 @@ def process_clip_state_dict_for_saving(self, state_dict):
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
@@ -228,7 +228,7 @@ def process_clip_state_dict_for_saving(self, state_dict):
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

class SSD1B(SDXL):
@@ -299,7 +299,7 @@ def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return None

class SV3D_u(SVD_img2vid):
@@ -365,7 +365,7 @@ def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
        return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return None

class SD_X4Upscaler(SD20):
@@ -439,7 +439,7 @@ def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_C(self, device=device)
        return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

class Stable_Cascade_B(Stable_Cascade_C):
@@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):

    unet_extra_config = {}
    latent_format = latent_formats.SD3
-    text_encoder_key_prefix = ["text_encoders."] #TODO?
+    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD3(self, device=device)
        return out

-    def clip_target(self):
-        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
+    def clip_target(self, state_dict={}):
+        clip_l = False
+        clip_g = False
+        t5 = False
+        pref = self.text_encoder_key_prefix[0]
+        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_l = True
+        if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_g = True
+        if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
+            t5 = True
+
+        class SD3ClipModel(sd3_clip.SD3ClipModel):
+            def __init__(self, device="cpu", dtype=None):
+                super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
+
+        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)


models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
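The new SD3.clip_target() decides the encoder configuration from the checkpoint itself: it probes for each encoder's final_layer_norm weight under the text_encoders. prefix and returns a thin subclass of sd3_clip.SD3ClipModel with the detected flags baked in, so the loader can keep instantiating the CLIP model with its usual (device, dtype) arguments. A quick sketch of the detection step against a made-up state dict that only ships clip_g and t5xxl weights:

pref = "text_encoders."
fake_sd = {  # hypothetical checkpoint keys, values omitted
    pref + "clip_g.transformer.text_model.final_layer_norm.weight": None,
    pref + "t5xxl.transformer.encoder.final_layer_norm.weight": None,
}

clip_l = "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in fake_sd
clip_g = "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in fake_sd
t5 = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in fake_sd

print(clip_l, clip_g, t5)  # False True True -> SD3ClipModel(clip_l=False, clip_g=True, t5=True)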
