Skip to content

Commit

Permalink
Fix vae decode export case returning tuple.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Mar 7, 2024
1 parent 93812b7 commit 7776cb5
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,25 @@ def export_submodel(args, submodel):
)
return unet_vmfb, unet_external_weight_path
case "vae_decode":
return (
vae.export_vae_model(
vae_torch,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.precision,
"vmfb",
args.external_weights,
vae_external_weight_path,
args.device,
args.iree_target_triple,
args.ireec_flags + args.attn_flags,
"decode",
args.decomp_attn,
exit_on_vmfb=False,
pipeline_dir=args.pipeline_dir,
),
vae_decode_vmfb, vae_external_weight_path = vae.export_vae_model(
vae_torch,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.precision,
"vmfb",
args.external_weights,
vae_external_weight_path,
args.device,
args.iree_target_triple,
args.ireec_flags + args.attn_flags,
"decode",
args.decomp_attn,
exit_on_vmfb=False,
pipeline_dir=args.pipeline_dir,
)
return vae_decode_vmfb, vae_external_weight_path
case "clip_1":
clip_1_vmfb, _ = clip.export_clip_model(
args.hf_model_name,
Expand Down Expand Up @@ -224,6 +222,7 @@ def generate_images(args, vmfbs: dict, weights: dict):
[vmfbs["scheduled_unet"], vmfbs["pipeline"]],
[weights["scheduled_unet"], None],
)
breakpoint()
vae_decode_runner = vmfbRunner(
args.rt_device, vmfbs["vae_decode"], weights["vae_decode"]
)
Expand Down

0 comments on commit 7776cb5

Please sign in to comment.