diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py index 576ec3e92..a4cc008f0 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py @@ -267,17 +267,15 @@ def run_forward( def export_pipeline_module(args): - from turbine_models.custom_models.sdxl_inference.pipeline_ir import ( - sdxl_sched_unet_bench_f32, - sdxl_sched_unet_bench_f16, - sdxl_pipeline_bench_f32, - sdxl_pipeline_bench_f16, - ) + from turbine_models.custom_models.sdxl_inference.pipeline_ir import get_pipeline_ir - pipeline_file = ( - sdxl_sched_unet_bench_f32 - if args.precision == "fp32" - else sdxl_sched_unet_bench_f16 + pipeline_file = get_pipeline_ir( + args.width, + args.height, + args.precision, + args.batch_size, + args.max_length, + "unet_loop", ) pipeline_vmfb = utils.compile_to_vmfb( pipeline_file, @@ -288,8 +286,13 @@ def export_pipeline_module(args): return_path=True, mlir_source="str", ) - full_pipeline_file = ( - sdxl_pipeline_bench_f32 if args.precision == "fp32" else sdxl_pipeline_bench_f16 + full_pipeline_file = get_pipeline_ir( + args.width, + args.height, + args.precision, + args.batch_size, + args.max_length, + "tokens_to_image", ) full_pipeline_vmfb = utils.compile_to_vmfb( pipeline_file,