diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
index 17328f6e..1ae7ea1f 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
@@ -48,111 +48,106 @@ def __init__(
     def forward(
         self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
     ):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_1 = self.text_encoder_model_1(
-                uncond_input_ids_1,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_2 = self.text_encoder_model_2(
-                uncond_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
-            neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_1 = self.text_encoder_model_1(
+            uncond_input_ids_1,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_2 = self.text_encoder_model_2(
+            uncond_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
+        neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
 
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            neg_prompt_embeds_list = [
-                neg_prompt_embeds_1.hidden_states[-2],
-                neg_prompt_embeds_2.hidden_states[-2],
-            ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        neg_prompt_embeds_list = [
+            neg_prompt_embeds_1.hidden_states[-2],
+            neg_prompt_embeds_2.hidden_states[-2],
+        ]
 
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-            neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
 
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        if not self.batch_input:
+            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        if not self.batch_input:
+            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        if self.do_classifier_free_guidance:
             if not self.batch_input:
-                prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
+                    1, -1
+                )
+            neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
+            neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
             if not self.batch_input:
-                add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
-            if self.do_classifier_free_guidance:
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        1, 1
-                    ).view(1, -1)
-                neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
-                neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-                if not self.batch_input:
-                    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
-                prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        self.batch_size, 1
-                    )
-                add_text_embeds = torch.cat(
-                    [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+                neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
+            prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
+            if not self.batch_input:
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    self.batch_size, 1
                 )
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+            add_text_embeds = torch.cat(
+                [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
     def forward_turbo(self, text_input_ids_1, text_input_ids_2):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
 
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            # neg_prompt_embeds_list = [
-            #     torch.zeros_like(prompt_embeds_list[0]),  # dummy tensor
-            #     torch.zeros_like(prompt_embeds_list[1]),  # dummy tensor
-            # ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        # neg_prompt_embeds_list = [
+        #     torch.zeros_like(prompt_embeds_list[0]),  # dummy tensor
+        #     torch.zeros_like(prompt_embeds_list[1]),  # dummy tensor
+        # ]
 
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
-            bs_embed, seq_len, _ = prompt_embeds.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape
 
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
-            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
-            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
 
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
 
+@torch.no_grad()
 def export_prompt_encoder(
     hf_model_name,
     hf_auth_token=None,
@@ -233,6 +228,20 @@ def export_prompt_encoder(
     if weights_only:
         return None, external_weight_path
 
+    example_inputs = {
+        "text_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "text_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+    }
     decomp_list = []
     if decomp_attn == True:
         decomp_list = [
@@ -244,40 +253,27 @@ def export_prompt_encoder(
         from_current=True,
         add_ops=decomp_list,
     ):
-
-        class CompiledClip(CompiledModule):
-            if external_weights:
-                params = export_parameters(
-                    prompt_encoder_module,
-                    external=True,
-                    external_scope="",
-                    name_mapper=mapper.get,
-                )
-            else:
-                params = export_parameters(prompt_encoder_module)
-
-            def encode_prompts(
-                self,
-                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            ):
-                return jittable(prompt_encoder_module.forward)(
-                    t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-                )
-
-            def encode_prompts_turbo(
-                self,
-                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+        if external_weights:
+            # Transformers (model source) registers position ids as non-persistent.
+            # This causes externalization to think it's a user input, and since it's not,
+            # we end up trying to do ops on a !torch.None instead of a tensor.
+            for buffer_name, buffer in prompt_encoder_module.named_buffers(
+                recurse=True
            ):
-                return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
-
-        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-        inst = CompiledClip(context=Context(), import_to=import_to)
-
-        module = CompiledModule.get_mlir_module(inst)
+                mod_name_list = buffer_name.split(".")
+                buffer_id = mod_name_list.pop()
+                parent = prompt_encoder_module
+                for i in mod_name_list:
+                    parent = getattr(parent, i)
+                parent.register_buffer(buffer_id, buffer, persistent=True)
+            externalize_module_parameters(prompt_encoder_module)
+        output = export(
+            prompt_encoder_module,
+            kwargs=example_inputs,
+            module_name="compiled_clip",
+            function_name="encode_prompts",
+        )
+        module = output.mlir_module
 
     model_metadata_encode = {
         "model_name": hf_model_name + "_text_encoder",
@@ -285,6 +281,7 @@ def encode_prompts_turbo(
         "input_dtypes": ["int64" for i in range(4)],
         "use_attention_mask": False,
     }
+    module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run()
     module_str = str(module)
 
     if compile_to != "vmfb":
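
Note on the buffer workaround in the third hunk: the following is a minimal, self-contained sketch (not part of the patch) of why re-registering non-persistent buffers as persistent matters before `externalize_module_parameters`. `TextEncoderLike` and `make_buffers_persistent` are hypothetical stand-ins introduced only for illustration; the patch itself performs the same walk inline over `prompt_encoder_module`.

```python
# Sketch, assuming a toy module that mirrors how transformers registers CLIP's
# position_ids as a non-persistent buffer. Re-registering each buffer with
# persistent=True puts it in the state dict, so parameter externalization treats
# it as a constant weight rather than a missing user input.
import torch


class TextEncoderLike(torch.nn.Module):
    # Hypothetical stand-in for the wrapped HF CLIP text encoder.
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(77, 8)
        self.register_buffer(
            "position_ids", torch.arange(77).unsqueeze(0), persistent=False
        )

    def forward(self, input_ids):
        return self.embed(input_ids) + self.embed(self.position_ids)


def make_buffers_persistent(module: torch.nn.Module) -> None:
    # Same walk as the patch: find each buffer's owning submodule and
    # re-register the buffer under its own name with persistent=True.
    for buffer_name, buffer in module.named_buffers(recurse=True):
        *parent_path, leaf = buffer_name.split(".")
        owner = module
        for name in parent_path:
            owner = getattr(owner, name)
        owner.register_buffer(leaf, buffer, persistent=True)


encoder = TextEncoderLike()
assert "position_ids" not in encoder.state_dict()
make_buffers_persistent(encoder)
# The buffer now appears in the state dict, so it is exported as data
# instead of surfacing as an extra (None) input to the compiled function.
assert "position_ids" in encoder.state_dict()
```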