Skip to content

Commit

Permalink
Pipe through attn spec option correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Apr 9, 2024
1 parent 946a02f commit 77d4308
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ jobs:
pytest models/turbine_models/tests/sd_test.py
pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --decomp_attn True
pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --decomp_attn True --attn_spec None
pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --decomp_attn True
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ def export_prompt_encoder(
else:
do_classifier_free_guidance = True

if attn_spec in ["default", "", None] and ("gfx9" in target_triple):
if (attn_spec in ["default", None]) and ("gfx94" in target_triple):
attn_spec = os.path.join(
os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir"
)
else:
attn_spec = None

if pipeline_dir not in [None, ""]:
safe_name = os.path.join(pipeline_dir, "prompt_encoder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,9 @@ def export_scheduled_unet_model(
do_classifier_free_guidance = False
else:
do_classifier_free_guidance = True

if (
(attn_spec in ["default", "", None])
and (decomp_attn is not None)
(attn_spec in ["default", None])
and decomp_attn == False
and ("gfx9" in iree_target_triple)
):
attn_spec = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def export_unet_model(
do_classifier_free_guidance = True

if (
(attn_spec in ["default", "", None])
and (decomp_attn is not None)
(attn_spec in ["default", None])
and decomp_attn == False
and ("gfx9" in target_triple)
):
attn_spec = os.path.join(
Expand Down
9 changes: 7 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,17 @@ def export_vae_model(
input_mlir=None,
weights_only=False,
):
if attn_spec in ["default", "", None] and ("gfx9" in target_triple):
if (
(attn_spec in ["default", None])
and decomp_attn == False
and ("gfx9" in target_triple)
):
attn_spec = os.path.join(
os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir"
)
if decomp_attn:
elif decomp_attn:
attn_spec = None

if pipeline_dir:
safe_name = os.path.join(pipeline_dir, "vae_" + variant)
else:
Expand Down
1 change: 1 addition & 0 deletions models/turbine_models/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def pytest_addoption(parser):
parser.addoption("--compile_to", action="store", default=None)
parser.addoption("--external_weights", action="store", default="safetensors")
parser.addoption("--decomp_attn", action="store", default=True)
parser.addoption("--attn_spec", action="store", default="")
# Compiler Options
parser.addoption("--device", action="store", default="cpu")
parser.addoption("--rt_device", action="store", default="local-task")
Expand Down
3 changes: 2 additions & 1 deletion models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def command_line_args(request):
arguments["compile_to"] = request.config.getoption("--compile_to")
arguments["external_weights"] = request.config.getoption("--external_weights")
arguments["decomp_attn"] = request.config.getoption("--decomp_attn")
arguments["attn_spec"] = request.config.getoption("--attn_spec")
arguments["device"] = request.config.getoption("--device")
arguments["rt_device"] = request.config.getoption("--rt_device")
arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple")
Expand Down Expand Up @@ -561,7 +562,7 @@ def test05_t2i_generate_images(self):
arguments["device"],
arguments["iree_target_triple"],
ireec_flags,
None, # attn_spec
arguments["attn_spec"],
arguments["decomp_attn"],
arguments["pipeline_dir"],
external_weights_dir,
Expand Down

0 comments on commit 77d4308

Please sign in to comment.