Skip to content

Commit

Permalink
Decompose VAE for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 4, 2024
1 parent e630d39 commit fc6d018
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def test00_sdxl_pipe(self):
decomp_attn = {
"text_encoder": True,
"unet": False,
"vae": False,
"vae": (
False
if any(x in arguments["device"] for x in ["hip", "rocm"])
else True
),
}
self.pipe = SharkSDPipeline(
arguments["hf_model_name"],
Expand Down Expand Up @@ -377,7 +381,7 @@ def test04_ExportVaeModelDecode(self):
"bs" + str(arguments["batch_size"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"vae",
"vae" if arguments["device"] != "cpu" else "vae_decomp_attn",
arguments["iree_target_triple"],
]
)
Expand Down

0 comments on commit fc6d018

Please sign in to comment.