Skip to content

Commit

Permalink
Explicitly set dtypes based on precision argument
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Mar 6, 2024
1 parent c47605b commit 4b28a12
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def export_submodel(args, submodel):
pipeline_file = (
"sdxl_sched_unet_bench_" + "f32"
if args.precision == "fp32"
else "sdxl_sched_unet_bench" + "f16"
else "sdxl_sched_unet_bench_" + "f16"
)
pipeline_vmfb = utils.compile_to_vmfb(
os.path.join(
Expand All @@ -203,7 +203,8 @@ def export_submodel(args, submodel):

def generate_images(args, vmfbs: dict, weights: dict):
pipe_start = time.time()
dtype = torch.float16 if args.precision == "fp16" else torch.float32
iree_dtype = "float32" if args.precision == "fp32" else "float16"
torch_dtype = torch.float32 if args.precision == "fp32" else torch.float16

all_imgs = []
generator = torch.manual_seed(0)
Expand All @@ -215,7 +216,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
args.width // 8,
),
generator=generator,
dtype=dtype,
dtype=torch_dtype,
)

pipe_runner = vmfbRunner(
Expand Down Expand Up @@ -250,19 +251,23 @@ def generate_images(args, vmfbs: dict, weights: dict):
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([pooled_negative_prompt_embeds, add_text_embeds], dim=0)

add_text_embeds = add_text_embeds.to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_text_embeds = add_text_embeds.to(torch_dtype)
prompt_embeds = prompt_embeds.to(torch_dtype)

unet_start = time.time()

unet_inputs = [
ireert.asdevicearray(pipe_runner.config.device, rand_sample),
ireert.asdevicearray(pipe_runner.config.device, prompt_embeds),
ireert.asdevicearray(pipe_runner.config.device, add_text_embeds),
ireert.asdevicearray(pipe_runner.config.device, rand_sample, dtype=iree_dtype),
ireert.asdevicearray(
pipe_runner.config.device, prompt_embeds, dtype=iree_dtype
),
ireert.asdevicearray(
pipe_runner.config.device, add_text_embeds, dtype=iree_dtype
),
ireert.asdevicearray(
pipe_runner.config.device,
np.asarray([args.guidance_scale]),
dtype="float32" if args.precision == "fp32" else "float16",
dtype=iree_dtype,
),
]
latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
Expand Down Expand Up @@ -395,6 +400,5 @@ def is_prepared(args, vmfbs, weights):
vmfbs[submodel] = vmfb
if weights[submodel] is None:
weights[submodel] = weight
assert is_prepared(args, vmfbs, weights)[0]
generate_images(args, vmfbs, weights)
print("Image generation complete.")

0 comments on commit 4b28a12

Please sign in to comment.