Skip to content

Commit

Permalink
Fixes for non-punet attn spec
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jul 8, 2024
1 parent de6bf7b commit 491d8ad
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
20 changes: 13 additions & 7 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,14 @@ def compile_to_vmfb(
# This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets.
# This is a temporary solution, and should be removed or largely disabled once the functionality of
# the TD spec is implemented in C++.
if attn_spec in ["default", "mfma", "i8"]:

if attn_spec in ["default", "mfma", "punet"]:
use_punet = True if attn_spec in ["punet", "i8"] else False
attn_spec = get_mfma_spec_path(
target_triple, os.path.dirname(safe_name), masked_attention
target_triple, os.path.dirname(safe_name), masked_attention, use_punet=use_punet
)
flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])

elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
attn_spec = get_wmma_spec_path(
target_triple, os.path.dirname(safe_name), masked_attention
Expand Down Expand Up @@ -307,15 +310,18 @@ def create_safe_name(hf_model_name, model_name_str=""):
return safe_name


def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False):
    """Fetch the MFMA attention/matmul transform-dialect spec and write it to disk.

    Selects one of three hosted .mlir spec files (punet/int8, unmasked, or
    masked/padded attention), downloads it, and saves it under ``save_dir``.

    Args:
        target_chip: Target triple (e.g. "gfx942"). Currently unused for URL
            selection; kept for signature parity with get_wmma_spec_path.
        save_dir: Directory the fetched spec file is written into.
        masked_attention: When True (and ``use_punet`` is False), fetch the
            padded/masked-attention variant of the spec.
        use_punet: When True, fetch the int8 punet spec; takes precedence
            over ``masked_attention``.

    Returns:
        Filesystem path of the spec file written under ``save_dir``.
    """
    # Pick the spec variant; punet takes priority over masked/unmasked choice.
    if use_punet:
        suffix = "_punet"
        url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir"
    elif not masked_attention:
        suffix = ""
        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir"
    else:
        suffix = "_pad"
        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
    # NOTE(review): the spec is re-downloaded on every call (the previous
    # os.path.exists early-return cache was removed by this commit) — confirm
    # that always refetching is the intended behavior.
    attn_spec = urlopen(url).read().decode("utf-8")
    # Suffix keeps the punet/padded variants from clobbering the default file.
    spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir")
    with open(spec_path, "w") as f:
        f.write(attn_spec)
    return spec_path
Expand Down
6 changes: 5 additions & 1 deletion models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def export_unet_model(
else:
submodel_name = "unet"
if (not decomp_attn) and use_punet:
attn_spec = "i8"
attn_spec = "punet"
elif (not decomp_attn) and "gfx9" in target:
attn_spec = "mfma"
elif (not decomp_attn) and "gfx11" in target:
attn_spec = "wmma"
safe_name = utils.create_safe_name(
hf_model_name,
f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}",
Expand Down

0 comments on commit 491d8ad

Please sign in to comment.