Skip to content

Commit

Permalink
Integrate int8 tk kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Jul 22, 2024
1 parent 2e9de46 commit 7649312
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 7 deletions.
107 changes: 100 additions & 7 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,70 @@ def iree_backend_map(device):
return iree_device


def _fetch_tk_kernel_body(kernel_url):
    """Download a tk kernel and return the MLIR text to splice into the IR.

    The first line of the downloaded file is expected to carry a
    ``#translation = ...`` definition; its value is inlined into line 10 of
    the file, and the first two and last three lines are dropped.
    NOTE(review): these fixed line offsets mirror the layout of the hosted
    kernel files -- confirm whenever the kernel files change.
    """
    # Close the HTTP response deterministically instead of leaking it.
    with urlopen(kernel_url) as response:
        data = response.read().decode("utf-8").split("\n")
    translation_info = data[0].split("#translation = ")[1].strip()
    data[10] = data[10].replace("#translation", translation_info)
    return "\n".join(data[2:-3])


def replace_with_tk_kernels(
    flow_dialect_ir,
):
    """Swap matching matmul dispatches in flow-dialect IR for tk kernels.

    Each known kernel URL ends in ``<name>_<B>x<M>x<N>x<K>[_bias<n>].mlir``.
    For every kernel, the IR is scanned for ``func.func`` lines mentioning
    ``matmul_transpose_b_<B>x<M>x<N>x<K>``; ``flow.dispatch`` lines that
    reference the recorded function are rewritten to call the tk kernel, and
    the downloaded kernel text is inserted immediately before the matching
    private ``flow.executable`` line.

    Args:
        flow_dialect_ir: The flow-dialect IR as a single string.

    Returns:
        A list of IR lines (some entries may be multi-line kernel bodies).
        If no kernel matches, the input lines are returned unchanged and no
        network access is performed.
    """
    kernels = [
        "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
    ]

    # Replace all calls to old kernel with new kernel
    print("Inserting kernels and updating calls to kernels...")

    # Parse every kernel file name once instead of once per IR line.
    # Each spec is (url, tk kernel name, old kernel name, expected arg count).
    kernel_specs = []
    for kernel in kernels:
        name = kernel.split("/")[-1].split(".")[0]
        name_parts = name.split("_")
        suffix = name_parts[-1]
        kernel_args = None
        if "bias" in suffix:
            kernel_args = 3 + int(suffix[4:])
            # The dimensions precede the bias suffix in the *file stem*; the
            # full URL must not be split on "." because the host name
            # contains dots.
            suffix = name_parts[-2]
        B, M, N, K = suffix.split("x")
        old_kernel = f"matmul_transpose_b_{B}x{M}x{N}x{K}"
        kernel_specs.append((kernel, name, old_kernel, kernel_args))

    kernel_map = {}  # url -> name of the function being replaced
    prefix_map = {}  # url -> name of the executable that holds that function

    new_base = []
    for line in flow_dialect_ir.split("\n"):
        for kernel, name, old_kernel, kernel_args in kernel_specs:
            if old_kernel not in line:
                continue
            if "func.func" in line:
                # Kernels with an explicit bias only match functions with the
                # expected number of arguments.
                if kernel_args is not None and line.count("arg") != kernel_args:
                    continue
                # Strip the leading "@" and the 7 trailing characters
                # (presumably "(%arg0:") from the function symbol.
                kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
                # Drop the 7 characters before the old kernel name
                # (presumably "_batch_") to get the executable name.
                prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-7]
            elif "flow.dispatch" in line and kernel in kernel_map:
                # Only rewrite once the function definition has been seen;
                # otherwise leave the dispatch untouched rather than crash.
                line = line.replace(kernel_map[kernel], name)
                line = line.replace(prefix_map[kernel], name)
        new_base.append(line)

    # Insert kernels in appropriate locations
    final_ir = []
    for line in new_base:
        if "flow.executable" in line and "private" in line:
            for kernel, _, _, _ in kernel_specs:
                # Kernels with no matching function have no entry here.
                prefix = prefix_map.get(kernel)
                if prefix is not None and prefix + " {" in line:
                    final_ir.append(_fetch_tk_kernel_body(kernel))
        final_ir.append(line)

    print("tk kernels added")
    return final_ir


def compile_to_vmfb(
module_str,
device,
Expand All @@ -161,6 +225,7 @@ def compile_to_vmfb(
winograd=False,
flagset_keywords=[],
debug=False,
add_tk_kernels=False,
):
flags = []
if mlir_source == "file" and not isinstance(module_str, str):
Expand Down Expand Up @@ -296,6 +361,34 @@ def compile_to_vmfb(
for idx, flag in enumerate(flags):
if flag is None:
flags.pop(idx)
input_ir_type = "torch"
if add_tk_kernels:
print("Adding tk kernels")
flags.extend(["--compile-to=flow"])
if mlir_source == "file":
flatbuffer_blob = ireec.compile_file(
module_str,
target_backends=[device],
input_type=input_ir_type,
extra_args=flags,
)
elif mlir_source == "str":
flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=[device],
input_type=input_ir_type,
extra_args=flags,
)

flow_ir = flatbuffer_blob.decode("utf-8")

flow_ir_tk = replace_with_tk_kernels(flow_ir)
module_str = "\n".join(flow_ir_tk)
flags.pop()
flags.extend(["--compile-from=flow"])
mlir_source = "str"
input_ir_type = "auto"

print("Compiling to", device, "with flags:", flags)

# Forces a standard for naming files:
Expand All @@ -312,7 +405,7 @@ def compile_to_vmfb(
flatbuffer_blob = ireec.compile_file(
module_str,
target_backends=[device],
input_type="torch",
input_type=input_ir_type,
extra_args=flags,
)
elif mlir_source == "str":
Expand All @@ -323,7 +416,7 @@ def compile_to_vmfb(
flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=[device],
input_type="torch",
input_type=input_ir_type,
extra_args=flags,
)
else:
Expand Down Expand Up @@ -431,11 +524,11 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
# model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,5 +369,11 @@ def is_valid_file(arg):
help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
)

def _parse_bool_flag(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so the text is parsed explicitly. Raising ValueError lets argparse report
    an invalid value with its usual error message.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


# Accepts both the bare flag (`--add_tk_kernels`) and an explicit value
# (`--add_tk_kernels True` / `--add_tk_kernels False`).
p.add_argument(
    "--add_tk_kernels",
    type=_parse_bool_flag,
    nargs="?",
    const=True,
    default=False,
    help="Flag to add compiled tk kernels.",
)

args, unknown = p.parse_known_args()
8 changes: 8 additions & 0 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def export_unet_model(
weights_only=False,
use_punet=False,
quant_paths=None,
add_tk_kernels=False,
):
if use_punet:
submodel_name = "punet"
Expand All @@ -209,6 +210,10 @@ def export_unet_model(
if decomp_attn == True:
ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False"

# Currently, only int8 tk kernels are integrated
if add_tk_kernels and precision != "i8":
add_tk_kernels = False

if input_mlir:
vmfb_path = utils.compile_to_vmfb(
input_mlir,
Expand All @@ -220,6 +225,7 @@ def export_unet_model(
return_path=not exit_on_vmfb,
attn_spec=attn_spec,
flagset_keywords=["punet"] if use_punet else [],
add_tk_kernels=add_tk_kernels,
)
return vmfb_path
elif use_punet:
Expand Down Expand Up @@ -355,6 +361,7 @@ class CompiledUnet(CompiledModule):
return_path=True,
attn_spec=attn_spec,
flagset_keywords=["punet"] if use_punet else [],
add_tk_kernels=add_tk_kernels,
)
if exit_on_vmfb:
exit()
Expand Down Expand Up @@ -393,6 +400,7 @@ class CompiledUnet(CompiledModule):
args.decomp_attn,
attn_spec=args.attn_spec,
input_mlir=args.input_mlir,
add_tk_kernels=args.add_tk_kernels,
)
if args.input_mlir:
exit()
Expand Down

0 comments on commit 7649312

Please sign in to comment.