From 491d8ad78ed28a48c9cd8e742a4b856c44caa815 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Mon, 8 Jul 2024 15:55:37 -0500
Subject: [PATCH] Fixes for non-punet attn spec

---
 .../custom_models/sd_inference/utils.py  | 20 +++++++++++++-------
 .../custom_models/sdxl_inference/unet.py |  6 +++++-
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 4673f4f24..9b4e6159b 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -229,11 +229,14 @@ def compile_to_vmfb(
     # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets.
     # This is a temporary solution, and should be removed or largely disabled once the functionality of
     # the TD spec is implemented in C++.
-    if attn_spec in ["default", "mfma", "i8"]:
+
+    if attn_spec in ["default", "mfma", "punet", "i8"]:
+        use_punet = attn_spec in ["punet", "i8"]
         attn_spec = get_mfma_spec_path(
-            target_triple, os.path.dirname(safe_name), masked_attention
+            target_triple, os.path.dirname(safe_name), masked_attention, use_punet=use_punet
         )
         flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
+
     elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
         attn_spec = get_wmma_spec_path(
             target_triple, os.path.dirname(safe_name), masked_attention
@@ -307,15 +310,18 @@ def create_safe_name(hf_model_name, model_name_str=""):
     return safe_name
 
 
-def get_mfma_spec_path(target_chip, save_dir, masked_attention=False):
-    if not masked_attention:
+def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False):
+    if use_punet:
+        suffix = "_punet"
         url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir"
+    elif not masked_attention:
+        suffix = ""
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir"
     else:
+        suffix = "_pad"
         url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
     attn_spec = urlopen(url).read().decode("utf-8")
-    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
-    if os.path.exists(spec_path):
-        return spec_path
+    spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir")
     with open(spec_path, "w") as f:
         f.write(attn_spec)
     return spec_path
diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py
index d8b6c41d2..39429e36a 100644
--- a/models/turbine_models/custom_models/sdxl_inference/unet.py
+++ b/models/turbine_models/custom_models/sdxl_inference/unet.py
@@ -193,7 +193,11 @@ def export_unet_model(
     else:
         submodel_name = "unet"
     if (not decomp_attn) and use_punet:
-        attn_spec = "i8"
+        attn_spec = "punet"
+    elif (not decomp_attn) and "gfx9" in target:
+        attn_spec = "mfma"
+    elif (not decomp_attn) and "gfx11" in target:
+        attn_spec = "wmma"
     safe_name = utils.create_safe_name(
         hf_model_name,
         f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}",
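
---

A minimal sketch of the new spec-file naming in get_mfma_spec_path(), for
review. pick_spec_filename is a hypothetical helper that only mirrors the
patched branch order (punet first, then unmasked, then padded); it does not
fetch the spec URLs.

    def pick_spec_filename(masked_attention=False, use_punet=False):
        # Mirrors the patched control flow in get_mfma_spec_path.
        if use_punet:
            suffix = "_punet"   # int8 punet spec from nod-ai/sdxl-scripts
        elif not masked_attention:
            suffix = ""         # unmasked (no-pad) mfma spec
        else:
            suffix = "_pad"     # masked/padded gfx942 spec
        return f"attention_and_matmul_spec_mfma{suffix}.mlir"

    assert pick_spec_filename(use_punet=True) == "attention_and_matmul_spec_mfma_punet.mlir"
    assert pick_spec_filename() == "attention_and_matmul_spec_mfma.mlir"
    assert pick_spec_filename(masked_attention=True) == "attention_and_matmul_spec_mfma_pad.mlir"

With the suffix in the filename, the punet, no-pad, and padded specs no longer
collide in save_dir, which is why the old early-return cache check is dropped.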
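Similarly, a condensed sketch of the new attn_spec defaulting in
export_unet_model(). resolve_attn_spec is a hypothetical stand-in for the
patched if/elif chain, and the gfx target strings are illustrative.

    def resolve_attn_spec(attn_spec, decomp_attn, use_punet, target):
        # punet takes priority; otherwise pick by GPU arch. Decomposed
        # attention (or no arch match) keeps whatever the caller passed.
        if (not decomp_attn) and use_punet:
            return "punet"
        if (not decomp_attn) and "gfx9" in target:
            return "mfma"
        if (not decomp_attn) and "gfx11" in target:
            return "wmma"
        return attn_spec

    assert resolve_attn_spec(None, False, True, "gfx942") == "punet"
    assert resolve_attn_spec(None, False, False, "gfx942") == "mfma"
    assert resolve_attn_spec(None, False, False, "gfx1100") == "wmma"
    assert resolve_attn_spec("default", True, False, "gfx942") == "default"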