From 0bef826a98dc93d59bb5f260175e449d587cf923 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 6 Mar 2025 00:24:43 -0500
Subject: [PATCH] Support llava clip vision model.

---
 comfy/clip_model.py                          | 20 +++++++++++++++++++-
 comfy/clip_vision.py                         |  6 +++++-
 comfy/clip_vision_config_vitl_336_llava.json | 19 +++++++++++++++++++
 comfy/sd1_clip.py                            | 19 ++++++++++++++++++-
 4 files changed, 61 insertions(+), 3 deletions(-)
 create mode 100644 comfy/clip_vision_config_vitl_336_llava.json

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 300b09ec7de..c8294d4832e 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -211,6 +211,15 @@ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
         pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
+class LlavaProjector(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, dtype, device, operations):
+        super().__init__()
+        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -220,7 +229,16 @@ def __init__(self, config_dict, dtype, device, operations):
         else:
             self.visual_projection = lambda a: a
 
+        if "llava3" == config_dict.get("projector_type", None):
+            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+        else:
+            self.multi_modal_projector = None
+
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
         out = self.visual_projection(x[2])
-        return (x[0], x[1], out)
+        projected = None
+        if self.multi_modal_projector is not None:
+            projected = self.multi_modal_projector(x[1])
+
+        return (x[0], x[1], out, projected)
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index c9c82e9ade0..297b3bca3e0 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -65,6 +65,7 @@ def encode_image(self, image, crop=True):
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        outputs["mm_projected"] = out[3]
         return outputs
 
 def convert_to_transformers(sd, prefix):
@@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
         elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
-            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+            if "multi_modal_projector.linear_1.bias" in sd:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+            else:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     else:
diff --git a/comfy/clip_vision_config_vitl_336_llava.json b/comfy/clip_vision_config_vitl_336_llava.json
new file mode 100644
index 00000000000..f23a50d8b77
--- /dev/null
+++ b/comfy/clip_vision_config_vitl_336_llava.json
@@ -0,0 +1,19 @@
+{
+    "attention_dropout": 0.0,
+    "dropout": 0.0,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "image_size": 336,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "layer_norm_eps": 1e-5,
+    "model_type": "clip_vision_model",
+    "num_attention_heads": 16,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "patch_size": 14,
+    "projection_dim": 768,
+    "projector_type": "llava3",
+    "torch_dtype": "float32"
+}
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 77514753551..22adcbac9a8 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -196,8 +196,25 @@ def process_tokens(self, tokens, device):
             index = 0
             pad_extra = 0
             for o in other_embeds:
+                emb = o[1]
+                if torch.is_tensor(emb):
+                    emb = {"type": "embedding", "data": emb}
+
+                emb_type = emb.get("type", None)
+                if emb_type == "embedding":
+                    emb = emb.get("data", None)
+                else:
+                    if hasattr(self.transformer, "preprocess_embed"):
+                        emb = self.transformer.preprocess_embed(emb, device=device)
+                    else:
+                        emb = None
+
+                if emb is None:
+                    index += -1
+                    continue
+
                 ind = index + o[0]
-                emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32)
+                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                 emb_shape = emb.shape[1]
                 if emb.shape[-1] == tokens_embed.shape[-1]:
                     tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
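
Note (not part of the patch): a minimal sketch of how the new "mm_projected" output could be consumed once this lands. It assumes a LLaVA-style CLIP ViT-L/14 336px checkpoint whose state dict contains the "multi_modal_projector.*" keys, so load_clipvision_from_sd picks the new clip_vision_config_vitl_336_llava.json; the checkpoint file name below is hypothetical.

    import torch
    import comfy.clip_vision

    # Hypothetical file name; any CLIP vision checkpoint with
    # "multi_modal_projector.linear_1.bias" in its state dict selects the new config.
    clip_vision = comfy.clip_vision.load("llava_clip_vit_l_336.safetensors")

    image = torch.rand(1, 336, 336, 3)  # ComfyUI image layout: [B, H, W, C], values in 0..1
    out = clip_vision.encode_image(image)

    out["image_embeds"].shape              # pooled CLIP embedding, [1, 768]
    out["penultimate_hidden_states"].shape # [1, 577, 1024]: CLS + 576 patch tokens
    # New in this patch: patch tokens (CLS dropped) projected to the LLM hidden size.
    out["mm_projected"].shape              # [1, 576, 4096]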