From 93812b7a032f7d999530b3c9e3b157fa9ce241c6 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Wed, 6 Mar 2024 20:17:34 -0600
Subject: [PATCH] Fixup fp16 pipeline

---
 .../custom_models/sdxl_inference/sdxl_pipeline.py       | 3 ++-
 .../custom_models/sdxl_inference/sdxl_scheduled_unet.py | 5 +++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
index a3ae609cf..884c979ee 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
@@ -270,6 +270,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
             dtype=iree_dtype,
         ),
     ]
+    print(unet_inputs)
     latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
         *unet_inputs,
     )
@@ -289,7 +290,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
     image[0].save(img_path)
     print(img_path, "saved")
     print("Pipeline arguments: ", args)
-    print("Total time: ", pipe_end - pipe_start, "sec")
+    print("Total time: ", pipe_end - clip_start, "sec")
     print("Loading time: ", clip_start - pipe_start, "sec")
     print("Clip time: ", unet_start - clip_start, "sec")
     print("UNet time(", args.num_inference_steps, "): ", vae_start - unet_start, "sec,")
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
index 45a9c7a6f..819e57a83 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
@@ -76,7 +76,8 @@ def initialize(self, sample):
         add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype)
         timesteps = self.scheduler.timesteps
         step_indexes = torch.tensor(len(timesteps))
-        return sample * self.scheduler.init_noise_sigma, add_time_ids, step_indexes
+        sample = sample * self.scheduler.init_noise_sigma
+        return sample.type(self.dtype), add_time_ids, step_indexes

     def forward(
         self, sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index
@@ -103,7 +104,7 @@ def forward(
                 noise_pred_text - noise_pred_uncond
             )
             sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[0]
-        return sample
+        return sample.type(self.dtype)


 def export_scheduled_unet_model(