diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
index c1c21301..368fb0d7 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
@@ -222,6 +222,13 @@ def is_valid_file(arg):
 ##############################################################################
 
 p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
+p.add_argument(
+    "--batch_prompt_input",
+    type=bool,
+    default=False,
+    help="If batch size > 1, this enables batching the prompt encoder input rather than concatenating the prompt encoder's outputs",
+)
+
 p.add_argument(
     "--height", type=int, default=1024, help="Height of Stable Diffusion output image."
 )
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
index 90b16d7b..ec88c525 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
@@ -68,6 +68,7 @@ def __init__(
         custom_vae: str = "",
         cpu_scheduling: bool = False,
         vae_precision: str = "fp32",
+        batch_prompt_input: bool = False,
     ):
         self.hf_model_name = hf_model_name
         self.scheduler_id = scheduler_id
@@ -76,6 +77,7 @@
         self.precision = precision
         self.max_length = max_length
         self.batch_size = batch_size
+        self.batch_prompt_input = batch_prompt_input
         self.num_inference_steps = num_inference_steps
         self.devices = {}
         if isinstance(device, dict):
@@ -375,7 +377,7 @@
                     "vmfb",
                     self.external_weights,
                     unet_external_weight_path,
-                    self.devices["unet"]["device"],
+                    self.devices["unet"]["driver"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["unet"],
                     self.decomp_attn,
@@ -404,7 +406,7 @@
                     "vmfb",
                     self.external_weights,
                     unet_external_weight_path,
-                    self.devices["unet"]["device"],
+                    self.devices["unet"]["driver"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["unet"],
                     self.decomp_attn,
@@ -429,7 +431,7 @@
                     self.num_inference_steps,
                     self.precision,
                     "vmfb",
-                    self.devices["unet"]["device"],
+                    self.devices["unet"]["driver"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["scheduler"],
                     exit_on_vmfb=False,
@@ -460,7 +462,7 @@
                     "vmfb",
                     self.external_weights,
                     vae_external_weight_path,
-                    self.devices["vae"]["device"],
+                    self.devices["vae"]["driver"],
                     self.devices["vae"]["target"],
                     self.ireec_flags["vae"],
                     "decode",
@@ -482,7 +484,7 @@
                     "vmfb",
                     self.external_weights,
                     prompt_encoder_external_weight_path,
-                    self.devices["clip"]["device"],
+                    self.devices["clip"]["driver"],
                     self.devices["clip"]["target"],
                     self.ireec_flags["clip"],
                     exit_on_vmfb=False,
@@ -492,7 +494,8 @@
                     input_mlir=input_mlir["prompt_encoder"],
                     attn_spec=self.attn_spec,
                     weights_only=weights_only,
-                    output_batchsize=self.batch_size,
+                    batchsize=self.batch_size,
+                    batch_input=self.batch_prompt_input,
                 )
                 return prompt_encoder_vmfb, prompt_encoder_external_weight_path
             case "unetloop":
@@ -514,7 +517,7 @@
                 ]
                 pipeline_vmfb = utils.compile_to_vmfb(
                     pipeline_file,
-                    self.devices["unet"]["device"],
+                    self.devices["unet"]["driver"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["unetloop"],
                     os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
@@ -541,7 +544,7 @@
                 ]
                 pipeline_vmfb = utils.compile_to_vmfb(
                     pipeline_file,
-                    self.devices["unet"]["device"],
+                    self.devices["unet"]["driver"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["unetloop"],
                     os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
@@ -1082,6 +1085,7 @@ def numpy_to_pil_image(images):
         args.vae_decomp_attn,
         custom_vae=None,
         vae_precision=args.vae_precision,
+        batch_prompt_input=args.batch_prompt_input,
     )
 
     vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights)
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
index 0c72d5c9..50d2d026 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
@@ -27,6 +27,7 @@ def __init__(
         hf_auth_token=None,
         do_classifier_free_guidance=True,
         batch_size=1,
+        batch_input=False,
     ):
         super().__init__()
         self.torch_dtype = torch.float16 if precision == "fp16" else torch.float32
@@ -42,6 +43,7 @@
         )
         self.do_classifier_free_guidance = True
         self.batch_size = batch_size
+        self.batch_input = batch_input
 
     def forward(
         self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
@@ -85,20 +87,25 @@ def forward(
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
             bs_embed * 1, -1
         )
-        prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        if not self.batch_input:
+            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
         add_text_embeds = pooled_prompt_embeds
-        add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        if not self.batch_input:
+            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
         if self.do_classifier_free_guidance:
-            neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
-                1, -1
-            )
+            if not self.batch_input:
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    1, 1
+                ).view(1, -1)
             neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
             neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
+            if not self.batch_input:
+                neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
             prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-            neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                self.batch_size, 1
-            )
+            if not self.batch_input:
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    self.batch_size, 1
+                )
             add_text_embeds = torch.cat(
                 [neg_pooled_prompt_embeds, add_text_embeds], dim=0
             )
@@ -163,6 +170,7 @@ def export_prompt_encoder(
     input_mlir=None,
     attn_spec=None,
    weights_only=False,
+    batch_input=False,
 ):
     do_classifier_free_guidance = True
 
@@ -206,7 +214,13 @@ def export_prompt_encoder(
         hf_auth_token,
         do_classifier_free_guidance,
         batch_size=batch_size,
+        batch_input=batch_input,
     )
+
+    input_batchsize = 1
+    if batch_input:
+        input_batchsize = batch_size
+
     if precision == "fp16":
         prompt_encoder_module = prompt_encoder_module.half()
     mapper = {}
@@ -231,10 +245,10 @@ class CompiledClip(CompiledModule):
 
         def encode_prompts(
             self,
-            t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
-            uc_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-            uc_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
+            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
         ):
             return jittable(prompt_encoder_module.forward)(
                 t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
@@ -242,8 +256,8 @@ def encode_prompts(
 
         def encode_prompts_turbo(
             self,
-            t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
+            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
         ):
             return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
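
For reviewers, a minimal sketch (not part of the patch) of the input-shape change the new --batch_prompt_input flag introduces. The variable names and torch.zeros placeholders below are illustrative assumptions; only the shapes mirror the AbstractTensor signatures in the hunks above.

import torch

batch_size, max_length = 2, 64

# Default path (batch_prompt_input=False): encode_prompts is exported with
# (1, max_length) token-id inputs, and the module repeats/concatenates the
# resulting embeddings up to batch_size internally.
t_ids_default = torch.zeros(1, max_length, dtype=torch.int64)

# With --batch_prompt_input: input_batchsize == batch_size, so the exported
# encode_prompts expects token ids that already carry the batch dimension.
t_ids_batched = torch.zeros(batch_size, max_length, dtype=torch.int64)

assert t_ids_default.shape == (1, max_length)
assert t_ids_batched.shape == (batch_size, max_length)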