diff --git a/comfy/clip_model.py b/comfy/clip_model.py index cf5b58b62a3..300b09ec7de 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,8 +97,12 @@ def __init__(self, config_dict, dtype, device, operations): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + if embeds is not None: + x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) + else: + x = self.embeddings(input_tokens, dtype=dtype) + mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) @@ -116,7 +120,10 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f if i is not None and final_layer_norm_intermediate: i = self.final_layer_norm(i) - pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] + if num_tokens is not None: + pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))] + else: + pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] return x, i, pooled_output class CLIPTextModel(torch.nn.Module): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 692ae05188b..77514753551 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -158,71 +158,75 @@ def reset_clip_options(self): self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] - def set_up_textual_embeddings(self, tokens, current_embeds): - out_tokens = [] - next_new_token = token_dict_size = current_embeds.weight.shape[0] - embedding_weights = [] + def process_tokens(self, tokens, device): + end_token = self.special_tokens.get("end", None) + if end_token is None: + cmp_token = self.special_tokens.get("pad", -1) + else: + cmp_token = end_token + + embeds_out = [] + attention_masks = [] + num_tokens = [] for x in tokens: + attention_mask = [] tokens_temp = [] + other_embeds = [] + eos = False + index = 0 for y in x: if isinstance(y, numbers.Integral): - tokens_temp += [int(y)] - else: - if y.shape[0] == current_embeds.weight.shape[1]: - embedding_weights += [y] - tokens_temp += [next_new_token] - next_new_token += 1 + if eos: + attention_mask.append(0) else: - logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) - while len(tokens_temp) < len(x): - tokens_temp += [self.special_tokens["pad"]] - out_tokens += [tokens_temp] - - n = token_dict_size - if len(embedding_weights) > 0: - new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) - new_embedding.weight[:token_dict_size] = current_embeds.weight - for x in embedding_weights: - new_embedding.weight[n] = x - n += 1 - self.transformer.set_input_embeddings(new_embedding) - - processed_tokens = [] - for x in out_tokens: - processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one - - return processed_tokens + attention_mask.append(1) + token = int(y) + tokens_temp += [token] + if not eos and token == cmp_token: + if end_token is None: + attention_mask[-1] = 0 + eos = True + else: + other_embeds.append((index, y)) + index += 1 + + tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long) + tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) + index = 0 + pad_extra = 0 + for o in other_embeds: + ind = index + o[0] + emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32) + emb_shape = emb.shape[1] + if emb.shape[-1] == tokens_embed.shape[-1]: + tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) + attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] + index += emb_shape - 1 + else: + index += -1 + pad_extra += emb_shape + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1])) - def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = self.set_up_textual_embeddings(tokens, backup_embeds) - tokens = torch.LongTensor(tokens).to(device) - - attention_mask = None - if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: - attention_mask = torch.zeros_like(tokens) - end_token = self.special_tokens.get("end", None) - if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) - else: - cmp_token = end_token + if pad_extra > 0: + padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32) + tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1) - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == cmp_token: - if end_token is None: - attention_mask[x, y] = 0 - break + embeds_out.append(tokens_embed) + attention_masks.append(attention_mask) + num_tokens.append(sum(attention_mask)) + + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens + + def forward(self, tokens): + device = self.transformer.get_input_embeddings().weight.device + embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) - self.transformer.set_input_embeddings(backup_embeds) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index d4edd5aa58e..551b0316269 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -93,8 +93,11 @@ def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_tok self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, input_tokens, token_type_ids=None, dtype=None): - x = self.word_embeddings(input_tokens, out_dtype=dtype) + def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.word_embeddings(input_tokens, out_dtype=dtype) x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x) if token_type_ids is not None: x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) @@ -113,8 +116,8 @@ def __init__(self, config_dict, dtype, device, operations): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3f234015ae8..58710b2bf83 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -241,8 +241,11 @@ def __init__(self, config, device=None, dtype=None, ops=None): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embed_tokens(x, out_dtype=dtype) + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) if self.normalize_in: x *= self.config.hidden_size ** 0.5 diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index df2b5b5cdc9..49f0ba4fe5e 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -239,8 +239,11 @@ def get_input_embeddings(self): def set_input_embeddings(self, embeddings): self.shared = embeddings - def forward(self, input_ids, *args, **kwargs): - x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) + def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs): + if input_ids is None: + x = embeds + else: + x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: x = torch.nan_to_num(x) #Fix for fp8 T5 base - return self.encoder(x, *args, **kwargs) + return self.encoder(x, attention_mask=attention_mask, **kwargs)