Rework prompt encoder export to use the aot.export API
eagarvey-amd committed Oct 21, 2024
1 parent f140926 commit 67e6558
Showing 1 changed file with 120 additions and 123 deletions.
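This commit replaces the CompiledModule/jittable staging path with a direct call to the aot.export API: the prompt-encoder module is traced from a dict of example inputs, and the MLIR is read off the export output. Below is a minimal sketch of that pattern, not the commit's code: the toy module and the iree.turbine import path are illustrative assumptions, while the kwargs/module_name/function_name arguments mirror the ones used in the diff.

import torch
from iree.turbine.aot import export  # assumed import path

class TinyEncoder(torch.nn.Module):
    # Stand-in for the real CLIP text-encoder wrapper.
    def forward(self, text_input_ids_1: torch.Tensor) -> torch.Tensor:
        return text_input_ids_1.to(torch.float32) * 2.0

example_inputs = {
    "text_input_ids_1": torch.empty((1, 77), dtype=torch.int64),
}
output = export(
    TinyEncoder(),
    kwargs=example_inputs,           # example inputs pin the traced shapes/dtypes
    module_name="compiled_clip",     # name of the emitted MLIR module
    function_name="encode_prompts",  # name of the public entry point
)
module = output.mlir_module          # MLIR module, ready for IREE compilation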
243 changes: 120 additions & 123 deletions
models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
@@ -48,111 +48,106 @@ def __init__(
def forward(
self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_1 = self.text_encoder_model_1(
-                uncond_input_ids_1,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_2 = self.text_encoder_model_2(
-                uncond_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
-            neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_1 = self.text_encoder_model_1(
+            uncond_input_ids_1,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_2 = self.text_encoder_model_2(
+            uncond_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
+        neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]

-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            neg_prompt_embeds_list = [
-                neg_prompt_embeds_1.hidden_states[-2],
-                neg_prompt_embeds_2.hidden_states[-2],
-            ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        neg_prompt_embeds_list = [
+            neg_prompt_embeds_1.hidden_states[-2],
+            neg_prompt_embeds_2.hidden_states[-2],
+        ]

-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-            neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)

-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
-            if not self.batch_input:
-                prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
-            if not self.batch_input:
-                add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
-            if self.do_classifier_free_guidance:
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        1, 1
-                    ).view(1, -1)
-                neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
-                neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-                if not self.batch_input:
-                    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
-                prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        self.batch_size, 1
-                    )
-                add_text_embeds = torch.cat(
-                    [neg_pooled_prompt_embeds, add_text_embeds], dim=0
-                )
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        if not self.batch_input:
+            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        if not self.batch_input:
+            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        if self.do_classifier_free_guidance:
+            if not self.batch_input:
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
+                    1, -1
+                )
+            neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
+            neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
+            if not self.batch_input:
+                neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
+            prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
+            if not self.batch_input:
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    self.batch_size, 1
+                )
+            add_text_embeds = torch.cat(
+                [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds

def forward_turbo(self, text_input_ids_1, text_input_ids_2):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]

-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            # neg_prompt_embeds_list = [
-            #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
-            #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
-            # ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        # neg_prompt_embeds_list = [
+        #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
+        #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
+        # ]

-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

-            bs_embed, seq_len, _ = prompt_embeds.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape

-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
-            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
-            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)

-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds


+@torch.no_grad()
def export_prompt_encoder(
hf_model_name,
hf_auth_token=None,
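The in-graph with torch.no_grad(): blocks removed above appear to be compensated by the @torch.no_grad() decorator now sitting on export_prompt_encoder, so inference stays gradient-free without tracing a grad-mode context manager into the exported graph. A small sketch of that pattern under this reading (toy names, not the commit's code):

import torch

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No `with torch.no_grad():` inside the traced function.
        return x * 2.0

@torch.no_grad()  # grad stays disabled for everything the export path runs
def run_export(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)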
Expand Down Expand Up @@ -233,6 +228,20 @@ def export_prompt_encoder(
if weights_only:
return None, external_weight_path

+    example_inputs = {
+        "text_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "text_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+    }
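The example_inputs added above pin the exported signature to four static (batch, max_length) int64 tensors, the same shape and dtype a CLIP tokenizer emits. For reference, a hedged sketch of producing matching token ids; the checkpoint name and max_length value are illustrative, since the script parameterizes both:

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer"
)
tokens = tokenizer(
    "a photo of a corgi",
    padding="max_length",
    max_length=77,  # CLIP's usual context length; the export uses its own max_length
    truncation=True,
    return_tensors="pt",
)
assert tokens.input_ids.shape == (1, 77)
assert tokens.input_ids.dtype == torch.int64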
decomp_list = []
if decomp_attn == True:
decomp_list = [
@@ -244,47 +253,35 @@
from_current=True,
add_ops=decomp_list,
):

-        class CompiledClip(CompiledModule):
-            if external_weights:
-                params = export_parameters(
-                    prompt_encoder_module,
-                    external=True,
-                    external_scope="",
-                    name_mapper=mapper.get,
-                )
-            else:
-                params = export_parameters(prompt_encoder_module)
-
-            def encode_prompts(
-                self,
-                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            ):
-                return jittable(prompt_encoder_module.forward)(
-                    t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-                )
-
-            def encode_prompts_turbo(
-                self,
-                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            ):
-                return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
-
-        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-        inst = CompiledClip(context=Context(), import_to=import_to)
-
-        module = CompiledModule.get_mlir_module(inst)
+        if external_weights:
+            # Transformers (model source) registers position ids as non-persistent.
+            # This causes externalization to think it's a user input, and since it's not,
+            # we end up trying to do ops on a !torch.None instead of a tensor.
+            for buffer_name, buffer in prompt_encoder_module.named_buffers(
+                recurse=True
+            ):
+                mod_name_list = buffer_name.split(".")
+                buffer_id = mod_name_list.pop()
+                parent = prompt_encoder_module
+                for i in mod_name_list:
+                    parent = getattr(parent, i)
+                parent.register_buffer(buffer_id, buffer, persistent=True)
+            externalize_module_parameters(prompt_encoder_module)
+        output = export(
+            prompt_encoder_module,
+            kwargs=example_inputs,
+            module_name="compiled_clip",
+            function_name="encode_prompts",
+        )
+        module = output.mlir_module

model_metadata_encode = {
"model_name": hf_model_name + "_text_encoder",
"input_shapes": [str((input_batchsize, max_length)) for i in range(4)],
"input_dtypes": ["int64" for i in range(4)],
"use_attention_mask": False,
}

module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run()
module_str = str(module)
if compile_to != "vmfb":
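Past the end of this hunk, the module string is compiled to a vmfb when compile_to == "vmfb". A hedged sketch of that last step using IREE's Python compiler API; the target backend and output path are illustrative assumptions, not taken from the script:

from iree.compiler import compile_str

vmfb_bytes = compile_str(
    module_str,                    # str(module) produced above
    target_backends=["llvm-cpu"],  # e.g. "vulkan-spirv" or "rocm" in practice
)
with open("sdxl_prompt_encoder.vmfb", "wb") as f:
    f.write(vmfb_bytes)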
