
Commit 85ef295

Make applying embeddings more efficient.
Adding new tokens no longer makes a whole copy of the embeddings weight, which can be massive on certain models.
comfyanonymous committed Mar 5, 2025
1 parent 5d84607 commit 85ef295
Showing 5 changed files with 84 additions and 64 deletions.
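
Editor's note: the change swaps a copy-the-whole-table strategy for an insert-into-the-sequence strategy when applying custom (e.g. textual inversion) embeddings. Below is a minimal sketch of the idea with made-up sizes and token ids; it is illustrative only, not the actual ComfyUI code.

import torch

vocab_size, embed_dim = 49408, 768                      # hypothetical CLIP-like sizes
word_embeddings = torch.nn.Embedding(vocab_size, embed_dim)
token_ids = torch.tensor([[49406, 320, 49407]])         # hypothetical prompt token ids
custom_vectors = torch.randn(2, embed_dim)              # a 2-vector custom embedding
insert_at = 2                                           # where it should land in the prompt

# Old approach: build a bigger embedding table and copy the entire weight into it.
bigger = torch.nn.Embedding(vocab_size + custom_vectors.shape[0], embed_dim)
bigger.weight.data[:vocab_size] = word_embeddings.weight.data   # full copy of the weight matrix
bigger.weight.data[vocab_size:] = custom_vectors

# New approach: embed the known tokens once, then splice the custom vectors into the sequence.
embeds = word_embeddings(token_ids)                     # (1, seq, dim)
embeds = torch.cat([embeds[:, :insert_at],
                    custom_vectors.unsqueeze(0),
                    embeds[:, insert_at:]], dim=1)      # no copy of the full table needed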
comfy/clip_model.py (13 changes: 10 additions & 3 deletions)
@@ -97,8 +97,12 @@ def __init__(self, config_dict, dtype, device, operations):
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
 
-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
-        x = self.embeddings(input_tokens, dtype=dtype)
+    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+        if embeds is not None:
+            x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+        else:
+            x = self.embeddings(input_tokens, dtype=dtype)
+
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -116,7 +120,10 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f
         if i is not None and final_layer_norm_intermediate:
             i = self.final_layer_norm(i)
 
-        pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+        if num_tokens is not None:
+            pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
+        else:
+            pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
         return x, i, pooled_output
 
 class CLIPTextModel(torch.nn.Module):
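Editor's note on the pooled_output change in the hunk above: when num_tokens is provided, the pooled vector is taken at index num_tokens - 1 for each sequence (the last counted token, typically the EOS position), presumably because the EOS id can no longer be found by scanning input_tokens when only embeds are passed in. A small illustrative sketch with made-up shapes, not code from the repository:

import torch

x = torch.randn(2, 77, 768)     # hypothetical hidden states: (batch, seq, dim)
num_tokens = [5, 9]             # counted tokens per sequence, EOS included

# Same fancy indexing as the diff: row i paired with column num_tokens[i] - 1.
pooled = x[list(range(x.shape[0])), [n - 1 for n in num_tokens]]
print(pooled.shape)             # torch.Size([2, 768])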
comfy/sd1_clip.py (108 changes: 56 additions & 52 deletions)
@@ -158,71 +158,75 @@ def reset_clip_options(self):
         self.layer_idx = self.options_default[1]
         self.return_projected_pooled = self.options_default[2]
 
-    def set_up_textual_embeddings(self, tokens, current_embeds):
-        out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
-        embedding_weights = []
-
-        for x in tokens:
-            tokens_temp = []
-            for y in x:
-                if isinstance(y, numbers.Integral):
-                    tokens_temp += [int(y)]
-                else:
-                    if y.shape[0] == current_embeds.weight.shape[1]:
-                        embedding_weights += [y]
-                        tokens_temp += [next_new_token]
-                        next_new_token += 1
-                    else:
-                        logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
-            while len(tokens_temp) < len(x):
-                tokens_temp += [self.special_tokens["pad"]]
-            out_tokens += [tokens_temp]
-
-        n = token_dict_size
-        if len(embedding_weights) > 0:
-            new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight
-            for x in embedding_weights:
-                new_embedding.weight[n] = x
-                n += 1
-            self.transformer.set_input_embeddings(new_embedding)
-
-        processed_tokens = []
-        for x in out_tokens:
-            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
-
-        return processed_tokens
-
-    def forward(self, tokens):
-        backup_embeds = self.transformer.get_input_embeddings()
-        device = backup_embeds.weight.device
-        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(device)
-
-        attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
-            attention_mask = torch.zeros_like(tokens)
-            end_token = self.special_tokens.get("end", None)
-            if end_token is None:
-                cmp_token = self.special_tokens.get("pad", -1)
-            else:
-                cmp_token = end_token
-
-            for x in range(attention_mask.shape[0]):
-                for y in range(attention_mask.shape[1]):
-                    attention_mask[x, y] = 1
-                    if tokens[x, y] == cmp_token:
-                        if end_token is None:
-                            attention_mask[x, y] = 0
-                        break
+    def process_tokens(self, tokens, device):
+        end_token = self.special_tokens.get("end", None)
+        if end_token is None:
+            cmp_token = self.special_tokens.get("pad", -1)
+        else:
+            cmp_token = end_token
+
+        embeds_out = []
+        attention_masks = []
+        num_tokens = []
+
+        for x in tokens:
+            attention_mask = []
+            tokens_temp = []
+            other_embeds = []
+            eos = False
+            index = 0
+            for y in x:
+                if isinstance(y, numbers.Integral):
+                    if eos:
+                        attention_mask.append(0)
+                    else:
+                        attention_mask.append(1)
+                    token = int(y)
+                    tokens_temp += [token]
+                    if not eos and token == cmp_token:
+                        if end_token is None:
+                            attention_mask[-1] = 0
+                        eos = True
+                else:
+                    other_embeds.append((index, y))
+                index += 1
+
+            tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
+            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
+            index = 0
+            pad_extra = 0
+            for o in other_embeds:
+                ind = index + o[0]
+                emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32)
+                emb_shape = emb.shape[1]
+                if emb.shape[-1] == tokens_embed.shape[-1]:
+                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
+                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
+                    index += emb_shape - 1
+                else:
+                    index += -1
+                    pad_extra += emb_shape
+                    logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
+
+            if pad_extra > 0:
+                padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
+                tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
+
+            embeds_out.append(tokens_embed)
+            attention_masks.append(attention_mask)
+            num_tokens.append(sum(attention_mask))
+
+        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
+
+    def forward(self, tokens):
+        device = self.transformer.get_input_embeddings().weight.device
+        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
 
         attention_mask_model = None
         if self.enable_attention_masks:
             attention_mask_model = attention_mask
 
-        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
-        self.transformer.set_input_embeddings(backup_embeds)
+        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
 
         if self.layer == "last":
             z = outputs[0].float()
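Editor's note on the splicing loop in process_tokens above: each entry recorded in other_embeds is a (position, tensor) pair; the tensor is inserted at that position with torch.cat while the attention mask grows by the same number of slots. A stripped-down sketch with hypothetical shapes, not the full method:

import torch

tokens_embed = torch.randn(1, 7, 768)       # embedded prompt: (1, seq, dim)
attention_mask = [1, 1, 1, 1, 0, 0, 0]
emb = torch.randn(1, 2, 768)                # a 2-token custom embedding to splice in
ind = 3                                     # insertion index inside the sequence

tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb.shape[1] + attention_mask[ind:]

print(tokens_embed.shape)    # torch.Size([1, 9, 768])
print(len(attention_mask))   # 9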
comfy/text_encoders/bert.py (11 changes: 7 additions & 4 deletions)
@@ -93,8 +93,11 @@ def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_tok
 
         self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
 
-    def forward(self, input_tokens, token_type_ids=None, dtype=None):
-        x = self.word_embeddings(input_tokens, out_dtype=dtype)
+    def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.word_embeddings(input_tokens, out_dtype=dtype)
         x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
         if token_type_ids is not None:
             x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
@@ -113,8 +116,8 @@ def __init__(self, config_dict, dtype, device, operations):
         self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
         self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
 
-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
-        x = self.embeddings(input_tokens, dtype=dtype)
+    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
comfy/text_encoders/llama.py (7 changes: 5 additions & 2 deletions)
@@ -241,8 +241,11 @@ def __init__(self, config, device=None, dtype=None, ops=None):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
-    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
-        x = self.embed_tokens(x, out_dtype=dtype)
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.embed_tokens(x, out_dtype=dtype)
 
         if self.normalize_in:
             x *= self.config.hidden_size ** 0.5
comfy/text_encoders/t5.py (9 changes: 6 additions & 3 deletions)
@@ -239,8 +239,11 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, embeddings):
         self.shared = embeddings
 
-    def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+    def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
+        if input_ids is None:
+            x = embeds
+        else:
+            x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
             x = torch.nan_to_num(x) #Fix for fp8 T5 base
-        return self.encoder(x, *args, **kwargs)
+        return self.encoder(x, attention_mask=attention_mask, **kwargs)
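Editor's note: the bert.py, llama.py and t5.py hunks all follow the same pattern; each backbone's forward now accepts a pre-computed embeds tensor and only falls back to its own token-embedding lookup when embeds is not supplied. A generic sketch of that fallback (a simplified stand-in, not the exact signatures above):

import torch

class TextEncoderStub(torch.nn.Module):
    # Minimal stand-in for the embeds-or-tokens fallback used in the diffs above.
    def __init__(self, vocab_size=32000, dim=768):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, dim)

    def forward(self, input_tokens=None, embeds=None):
        # Prefer pre-built embeddings, which may already contain spliced custom vectors.
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(input_tokens)
        return x  # downstream encoder layers would consume x here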
