Skip to content

Commit

Permalink
Fix pipeline ir import from sdxl_scheduled_unet script
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed May 31, 2024
1 parent 657edab commit e14d074
Showing 1 changed file with 15 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,15 @@ def run_forward(


def export_pipeline_module(args):
from turbine_models.custom_models.sdxl_inference.pipeline_ir import (
sdxl_sched_unet_bench_f32,
sdxl_sched_unet_bench_f16,
sdxl_pipeline_bench_f32,
sdxl_pipeline_bench_f16,
)
from turbine_models.custom_models.sdxl_inference.pipeline_ir import get_pipeline_ir

pipeline_file = (
sdxl_sched_unet_bench_f32
if args.precision == "fp32"
else sdxl_sched_unet_bench_f16
pipeline_file = get_pipeline_ir(
args.width,
args.height,
args.precision,
args.batch_size,
args.max_length,
"unet_loop",
)
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
Expand All @@ -288,8 +286,13 @@ def export_pipeline_module(args):
return_path=True,
mlir_source="str",
)
full_pipeline_file = (
sdxl_pipeline_bench_f32 if args.precision == "fp32" else sdxl_pipeline_bench_f16
full_pipeline_file = get_pipeline_ir(
args.width,
args.height,
args.precision,
args.batch_size,
args.max_length,
"tokens_to_image",
)
full_pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
Expand Down

0 comments on commit e14d074

Please sign in to comment.