Skip to content

Commit

Permalink
Explicitly set dtypes based on precision argument
Browse files (browse the repository at this point in the history)
  • Loading branch information
monorimet committed Mar 6, 2024
1 parent c47605b commit 56a8bfe
Showing 1 changed file with 13 additions and 9 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -203,7 +203,8 @@ def export_submodel(args, submodel):

def generate_images(args, vmfbs: dict, weights: dict):
pipe_start = time.time()
dtype = torch.float16 if args.precision == "fp16" else torch.float32
iree_dtype = "float32" if args.precision == "fp32" else "float16"
torch_dtype = torch.float32 if args.precision == "fp32" else torch.float16

all_imgs = []
generator = torch.manual_seed(0)
Expand All @@ -215,7 +216,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
args.width // 8,
),
generator=generator,
dtype=dtype,
dtype=torch_dtype,
)

pipe_runner = vmfbRunner(
Expand Down Expand Up @@ -250,19 +251,23 @@ def generate_images(args, vmfbs: dict, weights: dict):
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([pooled_negative_prompt_embeds, add_text_embeds], dim=0)

add_text_embeds = add_text_embeds.to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_text_embeds = add_text_embeds.to(torch_dtype)
prompt_embeds = prompt_embeds.to(torch_dtype)

unet_start = time.time()

unet_inputs = [
ireert.asdevicearray(pipe_runner.config.device, rand_sample),
ireert.asdevicearray(pipe_runner.config.device, prompt_embeds),
ireert.asdevicearray(pipe_runner.config.device, add_text_embeds),
ireert.asdevicearray(pipe_runner.config.device, rand_sample, dtype=iree_dtype),
ireert.asdevicearray(
pipe_runner.config.device, prompt_embeds, dtype=iree_dtype
),
ireert.asdevicearray(
pipe_runner.config.device, add_text_embeds, dtype=iree_dtype
),
ireert.asdevicearray(
pipe_runner.config.device,
np.asarray([args.guidance_scale]),
dtype="float32" if args.precision == "fp32" else "float16",
dtype=iree_dtype,
),
]
latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
Expand Down Expand Up @@ -395,6 +400,5 @@ def is_prepared(args, vmfbs, weights):
vmfbs[submodel] = vmfb
if weights[submodel] is None:
weights[submodel] = weight
assert is_prepared(args, vmfbs, weights)[0]
generate_images(args, vmfbs, weights)
print("Image generation complete.")

0 comments on commit 56a8bfe

Please sign in to comment.