From df1002e5bf3ea8f303bd284e9a4717d9ac7c5c02 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Apr 2024 13:05:41 -0500 Subject: [PATCH] Pipe through attn spec option correctly. --- .github/workflows/test_models.yml | 2 +- .../custom_models/sdxl_inference/sdxl_prompt_encoder.py | 9 +++++++-- .../custom_models/sdxl_inference/sdxl_scheduled_unet.py | 5 ++--- .../turbine_models/custom_models/sdxl_inference/unet.py | 4 ++-- .../turbine_models/custom_models/sdxl_inference/vae.py | 9 +++++++-- models/turbine_models/tests/conftest.py | 1 + models/turbine_models/tests/sdxl_test.py | 3 ++- 7 files changed, 22 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 1b51aaee1..75f2f65d9 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -62,4 +62,4 @@ jobs: pytest models/turbine_models/tests/sd_test.py pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --decomp_attn True pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux - pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --decomp_attn True --attn_spec None + pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --decomp_attn True diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index dc53d6bbe..4910bd827 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -160,11 +160,16 @@ def export_prompt_encoder( else: do_classifier_free_guidance = True - if attn_spec in ["default", "", None] and ("gfx9" in target_triple): + if ( + (attn_spec in ["default", None]) + and ("gfx94" in target_triple) + ): attn_spec = os.path.join( os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" ) - + else: + attn_spec = None + if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, "prompt_encoder") else: diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py index 2564a7b8c..42176f2de 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py @@ -139,10 +139,9 @@ def export_scheduled_unet_model( do_classifier_free_guidance = False else: do_classifier_free_guidance = True - if ( - (attn_spec in ["default", "", None]) - and (decomp_attn is not None) + (attn_spec in ["default", None]) + and decomp_attn == False and ("gfx9" in iree_target_triple) ): attn_spec = os.path.join( diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 83e60f9e9..6490ff00b 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -106,8 +106,8 @@ def export_unet_model( do_classifier_free_guidance = True if ( - (attn_spec in ["default", "", None]) - and (decomp_attn is not None) + (attn_spec in ["default", None]) + and decomp_attn == False and ("gfx9" in target_triple) ): attn_spec = os.path.join( diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index a1aaa235c..18cd0e53d 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -84,12 +84,17 @@ def export_vae_model( input_mlir=None, weights_only=False, ): - if attn_spec in ["default", "", None] and ("gfx9" in target_triple): + if ( + (attn_spec in ["default", None]) + and decomp_attn == False + and ("gfx9" in target_triple) + ): attn_spec = os.path.join( os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" ) - if decomp_attn: + elif decomp_attn: attn_spec = None + if pipeline_dir: safe_name = os.path.join(pipeline_dir, "vae_" + variant) else: diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py index b287b1924..7a1f55b1a 100644 --- a/models/turbine_models/tests/conftest.py +++ b/models/turbine_models/tests/conftest.py @@ -37,6 +37,7 @@ def pytest_addoption(parser): parser.addoption("--compile_to", action="store", default=None) parser.addoption("--external_weights", action="store", default="safetensors") parser.addoption("--decomp_attn", action="store", default=True) + parser.addoption("--attn_spec", action="store", default="") # Compiler Options parser.addoption("--device", action="store", default="cpu") parser.addoption("--rt_device", action="store", default="local-task") diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 3b351c6f3..362b86fb2 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -61,6 +61,7 @@ def command_line_args(request): arguments["compile_to"] = request.config.getoption("--compile_to") arguments["external_weights"] = request.config.getoption("--external_weights") arguments["decomp_attn"] = request.config.getoption("--decomp_attn") + arguments["attn_spec"] = request.config.getoption("--attn_spec") arguments["device"] = request.config.getoption("--device") arguments["rt_device"] = request.config.getoption("--rt_device") arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") @@ -561,7 +562,7 @@ def test05_t2i_generate_images(self): arguments["device"], arguments["iree_target_triple"], ireec_flags, - None, # attn_spec + arguments["attn_spec"], arguments["decomp_attn"], arguments["pipeline_dir"], external_weights_dir,