
Support llava clip vision model.
comfyanonymous committed Mar 6, 2025
1 parent 85ef295 commit 0bef826
Showing 4 changed files with 61 additions and 3 deletions.
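
In short: CLIP ViT-L/14 336 checkpoints that also carry llava's multimodal projector weights are now detected, loaded with a dedicated config, and expose the projected patch tokens as an extra "mm_projected" output from encode_image. A minimal usage sketch, assuming the usual comfy.clip_vision.load entry point, a hypothetical checkpoint path, and a random stand-in image tensor:

    import torch
    import comfy.clip_vision

    # Hypothetical path: any ViT-L/336 CLIP vision checkpoint that also carries
    # multi_modal_projector.* weights gets routed to the new llava config.
    clip_vision = comfy.clip_vision.load("models/clip_vision/llava_clip_vit_l_336.safetensors")

    image = torch.rand(1, 336, 336, 3)  # stand-in for a [B, H, W, C] image in 0..1
    out = clip_vision.encode_image(image)

    print(out["penultimate_hidden_states"].shape)  # per-token CLIP features
    print(out["mm_projected"].shape)               # new output; None for non-llava checkpoints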
20 changes: 19 additions & 1 deletion comfy/clip_model.py
@@ -211,6 +211,15 @@ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
         pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output

+class LlavaProjector(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, dtype, device, operations):
+        super().__init__()
+        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -220,7 +229,16 @@ def __init__(self, config_dict, dtype, device, operations):
         else:
             self.visual_projection = lambda a: a

+        if "llava3" == config_dict.get("projector_type", None):
+            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+        else:
+            self.multi_modal_projector = None
+
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
         out = self.visual_projection(x[2])
-        return (x[0], x[1], out)
+        projected = None
+        if self.multi_modal_projector is not None:
+            projected = self.multi_modal_projector(x[1])
+
+        return (x[0], x[1], out, projected)
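
The projector is a llava-style two-layer MLP with GELU, applied to the penultimate hidden states (x[1]) with the class token stripped. A minimal standalone sketch of the shape flow, using plain torch.nn.Linear in place of the operations wrapper and the ViT-L/336 sizes from the config added below:

    import torch

    hidden_size, llm_dim = 1024, 4096              # ViT-L width -> language-model embedding width
    linear_1 = torch.nn.Linear(hidden_size, llm_dim)
    linear_2 = torch.nn.Linear(llm_dim, llm_dim)

    penultimate = torch.rand(1, 577, hidden_size)  # [batch, CLS + 24*24 patch tokens, dim]
    x = penultimate[:, 1:]                         # drop the CLS token, keep the 576 patches
    projected = linear_2(torch.nn.functional.gelu(linear_1(x)))
    print(projected.shape)                         # torch.Size([1, 576, 4096])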
6 changes: 5 additions & 1 deletion comfy/clip_vision.py
@@ -65,6 +65,7 @@ def encode_image(self, image, crop=True):
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        outputs["mm_projected"] = out[3]
         return outputs

 def convert_to_transformers(sd, prefix):
@@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
         elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
-            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+            if "multi_modal_projector.linear_1.bias" in sd:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+            else:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     else:
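
Checkpoint routing stays shape-based: 1152-wide layers select the SigLIP 384 config, 577 position embeddings select ViT-L/336, and within that branch the presence of the projector bias key now selects the llava variant. A condensed sketch of just the new branch (the function name and config_dir parameter are illustrative, not part of the codebase):

    import os

    def pick_vitl_336_config(sd, config_dir):
        # llava-style checkpoints ship the projector next to the vision tower,
        # so this single key is enough to tell the two ViT-L/336 configs apart.
        if "multi_modal_projector.linear_1.bias" in sd:
            return os.path.join(config_dir, "clip_vision_config_vitl_336_llava.json")
        return os.path.join(config_dir, "clip_vision_config_vitl_336.json")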
19 changes: 19 additions & 0 deletions comfy/clip_vision_config_vitl_336_llava.json
@@ -0,0 +1,19 @@
+{
+  "attention_dropout": 0.0,
+  "dropout": 0.0,
+  "hidden_act": "quick_gelu",
+  "hidden_size": 1024,
+  "image_size": 336,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "layer_norm_eps": 1e-5,
+  "model_type": "clip_vision_model",
+  "num_attention_heads": 16,
+  "num_channels": 3,
+  "num_hidden_layers": 24,
+  "patch_size": 14,
+  "projection_dim": 768,
+  "projector_type": "llava3",
+  "torch_dtype": "float32"
+}
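
Aside from "projector_type", this is essentially the standard CLIP ViT-L/14 336 vision config; "projector_type" is what flips on the LlavaProjector branch in CLIPVisionModelProjection.__init__, and "hidden_size" becomes the projector's input width (the 4096 output width is fixed in code). A quick sanity check, assuming it is run from the repository root:

    import json

    # Path relative to the repository root; adjust if running from elsewhere.
    with open("comfy/clip_vision_config_vitl_336_llava.json") as f:
        cfg = json.load(f)

    # Mirrors the check added to CLIPVisionModelProjection.__init__.
    assert cfg.get("projector_type", None) == "llava3"
    print(cfg["hidden_size"], cfg["image_size"])  # 1024 336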
19 changes: 18 additions & 1 deletion comfy/sd1_clip.py
@@ -196,8 +196,25 @@ def process_tokens(self, tokens, device):
             index = 0
             pad_extra = 0
             for o in other_embeds:
+                emb = o[1]
+                if torch.is_tensor(emb):
+                    emb = {"type": "embedding", "data": emb}
+
+                emb_type = emb.get("type", None)
+                if emb_type == "embedding":
+                    emb = emb.get("data", None)
+                else:
+                    if hasattr(self.transformer, "preprocess_embed"):
+                        emb = self.transformer.preprocess_embed(emb, device=device)
+                    else:
+                        emb = None
+
+                if emb is None:
+                    index += -1
+                    continue
+
                 ind = index + o[0]
-                emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32)
+                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                 emb_shape = emb.shape[1]
                 if emb.shape[-1] == tokens_embed.shape[-1]:
                     tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
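
On the text-encoder side, entries in other_embeds may now be dicts with a "type" field instead of raw tensors; anything that is not a plain "embedding" is handed to the transformer's preprocess_embed hook (if it defines one) or dropped. A condensed sketch of that dispatch, with transformer and device as stand-ins:

    import torch

    def resolve_embed(emb, transformer, device):
        # Raw tensors keep the old behaviour: treat them as plain embeddings.
        if torch.is_tensor(emb):
            emb = {"type": "embedding", "data": emb}

        if emb.get("type", None) == "embedding":
            return emb.get("data", None)

        # Other embed types (e.g. image features) go through the text model's
        # hook, which can project them into its own embedding space.
        if hasattr(transformer, "preprocess_embed"):
            return transformer.preprocess_embed(emb, device=device)
        return None  # unknown type and no hook: the caller skips this entry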
