Skip to content

Commit

Permalink
Fixup fp16 pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd authored and monorimet committed Mar 7, 2024
1 parent 7f51cb3 commit 93812b7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
dtype=iree_dtype,
),
]
print(unet_inputs)
latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
*unet_inputs,
)
Expand All @@ -289,7 +290,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
image[0].save(img_path)
print(img_path, "saved")
print("Pipeline arguments: ", args)
print("Total time: ", pipe_end - pipe_start, "sec")
print("Total time: ", pipe_end - clip_start, "sec")
print("Loading time: ", clip_start - pipe_start, "sec")
print("Clip time: ", unet_start - clip_start, "sec")
print("UNet time(", args.num_inference_steps, "): ", vae_start - unet_start, "sec,")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def initialize(self, sample):
add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype)
timesteps = self.scheduler.timesteps
step_indexes = torch.tensor(len(timesteps))
return sample * self.scheduler.init_noise_sigma, add_time_ids, step_indexes
sample = sample * self.scheduler.init_noise_sigma
return sample.type(self.dtype), add_time_ids, step_indexes

def forward(
self, sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index
Expand All @@ -103,7 +104,7 @@ def forward(
noise_pred_text - noise_pred_uncond
)
sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[0]
return sample
return sample.type(self.dtype)


def export_scheduled_unet_model(
Expand Down

0 comments on commit 93812b7

Please sign in to comment.