diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 3eb76877..55a2fc96 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -157,10 +157,16 @@ def iree_backend_map(device): def replace_with_tk_kernels( flow_dialect_ir, + batch_size ): - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir" - ] + if batch_size == 8: + kernels = [ + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir" + ] + if batch_size == 1: + kernels = [ + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" + ] # Replace all calls to old kernel with new kernel print("Inserting kernels and updating calls to kernels...") @@ -235,7 +241,10 @@ def compile_to_vmfb( flagset_keywords=[], debug=False, add_tk_kernels=False, + batch_size=1, ): + if batch_size != 1 and batch_size != 8: + add_tk_kernels = False flags = [] if mlir_source == "file" and not isinstance(module_str, str): module_str = str(module_str) @@ -393,7 +402,7 @@ def compile_to_vmfb( flow_ir = flatbuffer_blob.decode("utf-8") - flow_ir_tk = replace_with_tk_kernels(flow_ir) + flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size) module_str = "\n".join(flow_ir_tk) flags.pop() flags.extend(["--compile-from=flow"]) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index acccc391..4ed874e2 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -370,6 +370,7 @@ class CompiledUnet(CompiledModule): attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], add_tk_kernels=add_tk_kernels, + batch_size=batch_size, ) if exit_on_vmfb: exit()