diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
index af2677075..0974b60c3 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
@@ -203,7 +203,8 @@ def export_submodel(args, submodel):
 
 def generate_images(args, vmfbs: dict, weights: dict):
     pipe_start = time.time()
-    dtype = torch.float16 if args.precision == "fp16" else torch.float32
+    iree_dtype = "float32" if args.precision == "fp32" else "float16"
+    torch_dtype = torch.float32 if args.precision == "fp32" else torch.float16
 
     all_imgs = []
     generator = torch.manual_seed(0)
@@ -215,7 +216,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
             args.width // 8,
         ),
         generator=generator,
-        dtype=dtype,
+        dtype=torch_dtype,
     )
 
     pipe_runner = vmfbRunner(
@@ -250,19 +251,23 @@ def generate_images(args, vmfbs: dict, weights: dict):
 
     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
     add_text_embeds = torch.cat([pooled_negative_prompt_embeds, add_text_embeds], dim=0)
-    add_text_embeds = add_text_embeds.to(dtype)
-    prompt_embeds = prompt_embeds.to(dtype)
+    add_text_embeds = add_text_embeds.to(torch_dtype)
+    prompt_embeds = prompt_embeds.to(torch_dtype)
 
     unet_start = time.time()
 
     unet_inputs = [
-        ireert.asdevicearray(pipe_runner.config.device, rand_sample),
-        ireert.asdevicearray(pipe_runner.config.device, prompt_embeds),
-        ireert.asdevicearray(pipe_runner.config.device, add_text_embeds),
+        ireert.asdevicearray(pipe_runner.config.device, rand_sample, dtype=iree_dtype),
+        ireert.asdevicearray(
+            pipe_runner.config.device, prompt_embeds, dtype=iree_dtype
+        ),
+        ireert.asdevicearray(
+            pipe_runner.config.device, add_text_embeds, dtype=iree_dtype
+        ),
         ireert.asdevicearray(
             pipe_runner.config.device,
             np.asarray([args.guidance_scale]),
-            dtype="float32" if args.precision == "fp32" else "float16",
+            dtype=iree_dtype,
         ),
     ]
     latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
@@ -395,6 +400,5 @@ def is_prepared(args, vmfbs, weights):
                     vmfbs[submodel] = vmfb
                 if weights[submodel] is None:
                     weights[submodel] = weight
-    assert is_prepared(args, vmfbs, weights)[0]
     generate_images(args, vmfbs, weights)
     print("Image generation complete.")