Skip to content

Commit

Permalink
typo fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 4, 2024
1 parent f39b2d2 commit d3c8e80
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test00_sdxl_pipe(self):
)
assert output is not None

def test01_sdxl_pipe_i8(self):
def test01_sdxl_pipe_i8_punet(self):
from turbine_models.custom_models.sd_inference.sd_pipeline import (
SharkSDPipeline,
)
Expand Down Expand Up @@ -192,24 +192,24 @@ def test02_PromptEncoder(self):
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest("Compilation error on vulkan; To be tested on cuda.")
clip_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
arguments["precision"],
"text_encoder",
arguments["device"],
arguments["iree_target_triple"],
)
])
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename)
clip_w_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"text_encoder",
arguments["precision"],
)
])
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down Expand Up @@ -241,7 +241,7 @@ def test02_PromptEncoder(self):
turbine_output2,
) = sdxl_prompt_encoder_runner.run_prompt_encoder(
arguments["vmfb_path"],
arguments["rt_driver"],
arguments["rt_device"],
arguments["external_weight_path"],
text_input_ids_list,
uncond_input_ids_list,
Expand All @@ -259,7 +259,7 @@ def test02_PromptEncoder(self):
"prompt_encoder",
arguments["vmfb_path"],
arguments["external_weight_path"],
arguments["rt_driver"],
arguments["rt_device"],
max_length=arguments["max_length"],
tracy_profile=arguments["tracy_profile"],
)
Expand All @@ -272,7 +272,7 @@ def test03_unet(self):
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest("Unknown error on vulkan; To be tested on cuda.")
unet_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
Expand All @@ -281,16 +281,16 @@ def test03_unet(self):
"unet",
arguments["device"],
arguments["iree_target_triple"],
)
])
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename)
unet_w_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"unet",
arguments["precision"],
)
])
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down Expand Up @@ -363,24 +363,24 @@ def test04_ExportVaeModelDecode(self):
self.skipTest("Compilation error on vulkan; To be tested on cuda.")

vae_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"vae",
arguments["device"],
arguments["iree_target_triple"],
)
])
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename)
vae_w_filename = (
"_".join(
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"vae",
arguments["precision"],
)
])
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down

0 comments on commit d3c8e80

Please sign in to comment.