Skip to content

Commit

Permalink
fixes to sd tests
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 29, 2024
1 parent b65b577 commit d1f5765
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
8 changes: 4 additions & 4 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,8 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
if __name__ == "__main__":
from .sd_cmd_opts import args

mod_str, _ = export_clip(
mod_str, _ = export_clip_model(
args.hf_model_name,
args.hf_auth_token,
args.batch_size,
args.max_length,
args.precision,
args.compile_to,
Expand All @@ -195,7 +193,9 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
exit_on_vmfb=True,
pipeline_dir=args.pipeline_dir,
input_mlir=args.input_mlir,
attn_spec=args.attn_spec,
td_spec=args.attn_spec,
weights_only=False,
upload_ir=False,
)
if args.input_mlir:
exit()
Expand Down
57 changes: 32 additions & 25 deletions models/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ class StableDiffusionTest(unittest.TestCase):
def testExportT5Model(self):
current_args = copy.deepcopy(default_arguments)
current_args["hf_model_name"] = "google/t5-v1_1-small"
safe_prefix = "t5_v1_1_small"
blob_name = clip.export_clip_model(
hf_model_name=current_args["hf_model_name"],
hf_auth_token=None,
max_length=64,
precision=current_args["precision"],
compile_to="vmfb",
external_weights=None,
external_weight_path=None,
device="cpu",
target_triple=None,
exit_on_vmfb=False,
upload_ir=UPLOAD_IR,
)
current_args["vmfb_path"] = blob_name
Expand Down Expand Up @@ -119,12 +120,14 @@ def testExportClipVitLarge14(self):
safe_prefix = "clip_vit_large_patch14"
blob_name = clip.export_clip_model(
hf_model_name=current_args["hf_model_name"],
hf_auth_token=None,
max_length=64,
precision=current_args["precision"],
compile_to="vmfb",
external_weights="safetensors",
external_weight_path=safe_prefix + ".safetensors",
device="cpu",
target_triple=None,
exit_on_vmfb=False,
upload_ir=UPLOAD_IR,
)
current_args["external_weight_path"] = safe_prefix + ".safetensors"
Expand Down Expand Up @@ -156,13 +159,15 @@ def testExportClipModel(self):
current_args = copy.deepcopy(default_arguments)
current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4"
blob_name = clip.export_clip_model(
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
None,
"vmfb",
"safetensors",
"stable_diffusion_v1_4_clip.safetensors",
"cpu",
hf_model_name=current_args["hf_model_name"],
max_length=64,
precision=current_args["precision"],
compile_to="vmfb",
external_weights="safetensors",
external_weight_path=safe_prefix + ".safetensors",
device="cpu",
target_triple=None,
exit_on_vmfb=False,
upload_ir=UPLOAD_IR,
)
current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors"
Expand Down Expand Up @@ -194,7 +199,7 @@ def testExportUnetModel(self):
current_args = copy.deepcopy(default_arguments)
blob_name = unet.export_unet_model(
unet_model,
"CompVis/stable-diffusion-v1-4",
current_args["hf_model_name"],
current_args["batch_size"],
current_args["height"],
current_args["width"],
Expand All @@ -203,12 +208,12 @@ def testExportUnetModel(self):
None,
"vmfb",
"safetensors",
"stable_diffusion_v1_4_unet.safetensors",
"stable_diffusion_unet.safetensors",
"cpu",
upload_ir=UPLOAD_IR,
)
current_args["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors"
current_args["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb"
current_args["external_weight_path"] = "stable_diffusion_unet.safetensors"
current_args["vmfb_path"] = blob_name
sample = torch.rand(
current_args["batch_size"],
current_args["in_channels"],
Expand All @@ -219,9 +224,13 @@ def testExportUnetModel(self):

timestep = torch.zeros(1, dtype=torch.float32)
if current_args["hf_model_name"] == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
encoder_hidden_states = torch.rand(
2, current_args["max_length"], 768, dtype=torch.float32
)
elif current_args["hf_model_name"] == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32)
encoder_hidden_states = torch.rand(
2, current_args["max_length"], 1024, dtype=torch.float32
)
guidance_scale = torch.tensor(
[current_args["guidance_scale"]], dtype=torch.float32
)
Expand Down Expand Up @@ -251,21 +260,20 @@ def testExportUnetModel(self):
new_blob_name = blob_name.split(".")
new_blob_name = new_blob_name[0] + "-pass.mlir"
turbine_tank.changeBlobName(blob_name, new_blob_name)
os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4_unet.vmfb")
os.remove("stable_diffusion_unet.safetensors")
os.remove(blob_name)
del torch_output
del turbine

def testExportVaeModelDecode(self):
current_args = copy.deepcopy(default_arguments)
blob_name = vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
current_args["hf_model_name"],
current_args["batch_size"],
current_args["height"],
current_args["width"],
None,
current_args["precision"],
"vmfb",
"safetensors",
"stable_diffusion_v1_4_vae.safetensors",
Expand Down Expand Up @@ -303,14 +311,13 @@ def testExportVaeModelDecode(self):
del torch_output
del turbine
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")
os.remove("blob_name")

def testExportVaeModelEncode(self):
current_args = copy.deepcopy(default_arguments)
blob_name = vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
current_args["hf_model_name"],
current_args["batch_size"],
current_args["height"],
current_args["width"],
Expand Down Expand Up @@ -350,7 +357,7 @@ def testExportVaeModelEncode(self):
new_blob_name = new_blob_name[0] + "-pass.mlir"
turbine_tank.changeBlobName(blob_name, new_blob_name)
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")
os.remove(blob_name)

@unittest.expectedFailure
def testExportPNDMScheduler(self):
Expand Down

0 comments on commit d1f5765

Please sign in to comment.