From 5876436064aa6b4a49d9004088550277e663b6e4 Mon Sep 17 00:00:00 2001
From: nithinsubbiah
Date: Wed, 24 Jul 2024 21:35:36 -0500
Subject: [PATCH 1/2] [tk kernel] Add support to match kernel with number of
 arguments and update kernel links

---
 .../custom_models/sd_inference/utils.py | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 9d5c149a..2044863d 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -158,11 +158,12 @@ def iree_backend_map(device):
 def replace_with_tk_kernels(flow_dialect_ir, batch_size):
     if batch_size == 8:
         kernels = [
-            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir"
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir"
         ]
     if batch_size == 1:
         kernels = [
-            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir",
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir"
         ]

     # Replace all calls to old kernel with new kernel
@@ -178,20 +179,24 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
     for line in base:
         for kernel in kernels:
             suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1]
-            bias_explicit = False
-            if "bias" in suffix:
-                bias_explicit = True
-                kernel_args = 3 + int(suffix[4:])
-                suffix = kernel.split(".")[0].split("_")[-2]
+            # Uncomment/rework when a kernel with bias comes in
+            # bias_explicit = False
+            # if "bias" in suffix:
+            #     bias_explicit = True
+            #     kernel_args = 3 + int(suffix[4:])
+            #     suffix = kernel.split(".")[0].split("_")[-2]
             B, M, N, K = suffix.split("x")
             old_kernel = f"matmul_like_{B}x{M}x{N}x{K}"
             if not old_kernel in line:
                 continue
             if old_kernel in line and "func.func" in line:
-                if bias_explicit:
-                    num_args = line.count("arg")
-                    if num_args != kernel_args:
-                        continue
+                data = urlopen(kernel).read().decode("utf-8")
+                data = data.split("\n")
+                idx_with_kernel_args = [idx for idx, s in enumerate(data) if 'func.func' in s][0]
+                kernel_args = data[idx_with_kernel_args].count('arg')
+                num_args = line.count("arg")
+                if num_args != kernel_args:
+                    continue
                 kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
                 prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1]
             if (

From 51ef1db277efe120d18a23c1b576ba95f5133e2e Mon Sep 17 00:00:00 2001
From: nithinsubbiah
Date: Thu, 25 Jul 2024 00:00:20 -0500
Subject: [PATCH 2/2] Fix formatting

---
 .../custom_models/sd_inference/utils.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 2044863d..5bee5a09 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -163,7 +163,7 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
     if batch_size == 1:
         kernels = [
             "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir",
-            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir"
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir",
         ]

     # Replace all calls to old kernel with new kernel
@@ -192,8 +192,10 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
             if old_kernel in line and "func.func" in line:
                 data = urlopen(kernel).read().decode("utf-8")
                 data = data.split("\n")
-                idx_with_kernel_args = [idx for idx, s in enumerate(data) if 'func.func' in s][0]
-                kernel_args = data[idx_with_kernel_args].count('arg')
+                idx_with_kernel_args = [
+                    idx for idx, s in enumerate(data) if "func.func" in s
+                ][0]
+                kernel_args = data[idx_with_kernel_args].count("arg")
                 num_args = line.count("arg")
                 if num_args != kernel_args:
                     continue
@@ -552,11 +554,11 @@ def get_schedulers(model_id):
         model_id,
         subfolder="scheduler",
     )
-    schedulers["EulerAncestralDiscrete"] = (
-        EulerAncestralDiscreteScheduler.from_pretrained(
-            model_id,
-            subfolder="scheduler",
-        )
+    schedulers[
+        "EulerAncestralDiscrete"
+    ] = EulerAncestralDiscreteScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
     )
     # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
     #     model_id,