diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 9d5c149a..5bee5a09 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -158,11 +158,12 @@ def iree_backend_map(device): def replace_with_tk_kernels(flow_dialect_ir, batch_size): if batch_size == 8: kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir" + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir" ] if batch_size == 1: kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir", + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir", ] # Replace all calls to old kernel with new kernel @@ -178,20 +179,26 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): for line in base: for kernel in kernels: suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1] - bias_explicit = False - if "bias" in suffix: - bias_explicit = True - kernel_args = 3 + int(suffix[4:]) - suffix = kernel.split(".")[0].split("_")[-2] + # Uncomment/rework when a kernel with bias comes in + # bias_explicit = False + # if "bias" in suffix: + # bias_explicit = True + # kernel_args = 3 + int(suffix[4:]) + # suffix = kernel.split(".")[0].split("_")[-2] B, M, N, K = suffix.split("x") old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" if not old_kernel in line: continue if old_kernel in line and "func.func" in line: - if bias_explicit: - num_args = line.count("arg") - if num_args != kernel_args: - continue + data = urlopen(kernel).read().decode("utf-8") + data = data.split("\n") + idx_with_kernel_args = [ + idx for idx, s in enumerate(data) if "func.func" in s + ][0] + kernel_args = data[idx_with_kernel_args].count("arg") + num_args = line.count("arg") + if num_args != kernel_args: + continue kernel_map[kernel] = line.strip().split(" ")[1][1:-7] prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1] if ( @@ -547,11 +554,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["EulerAncestralDiscrete"] = ( - EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", ) # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained( # model_id,