diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py
index a15426d28..52c36a5c3 100644
--- a/models/turbine_models/custom_models/sd_inference/clip.py
+++ b/models/turbine_models/custom_models/sd_inference/clip.py
@@ -180,10 +180,8 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
 if __name__ == "__main__":
     from .sd_cmd_opts import args
 
-    mod_str, _ = export_clip(
+    mod_str, _ = export_clip_model(
         args.hf_model_name,
-        args.hf_auth_token,
-        args.batch_size,
         args.max_length,
         args.precision,
         args.compile_to,
@@ -195,7 +193,9 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
         exit_on_vmfb=True,
         pipeline_dir=args.pipeline_dir,
         input_mlir=args.input_mlir,
-        attn_spec=args.attn_spec,
+        td_spec=args.attn_spec,
+        weights_only=False,
+        upload_ir=False,
     )
     if args.input_mlir:
         exit()
diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py
index b0829103f..7af7dcb10 100644
--- a/models/turbine_models/tests/sd_test.py
+++ b/models/turbine_models/tests/sd_test.py
@@ -80,15 +80,16 @@ class StableDiffusionTest(unittest.TestCase):
     def testExportT5Model(self):
         current_args = copy.deepcopy(default_arguments)
         current_args["hf_model_name"] = "google/t5-v1_1-small"
-        safe_prefix = "t5_v1_1_small"
         blob_name = clip.export_clip_model(
             hf_model_name=current_args["hf_model_name"],
-            hf_auth_token=None,
+            max_length=64,
+            precision=current_args["precision"],
             compile_to="vmfb",
             external_weights=None,
             external_weight_path=None,
             device="cpu",
             target_triple=None,
+            exit_on_vmfb=False,
             upload_ir=UPLOAD_IR,
         )
         current_args["vmfb_path"] = blob_name
@@ -119,12 +120,14 @@ def testExportClipVitLarge14(self):
         safe_prefix = "clip_vit_large_patch14"
         blob_name = clip.export_clip_model(
             hf_model_name=current_args["hf_model_name"],
-            hf_auth_token=None,
+            max_length=64,
+            precision=current_args["precision"],
             compile_to="vmfb",
             external_weights="safetensors",
             external_weight_path=safe_prefix + ".safetensors",
             device="cpu",
             target_triple=None,
+            exit_on_vmfb=False,
             upload_ir=UPLOAD_IR,
         )
         current_args["external_weight_path"] = safe_prefix + ".safetensors"
@@ -156,13 +159,15 @@ def testExportClipModel(self):
         current_args = copy.deepcopy(default_arguments)
         current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4"
         blob_name = clip.export_clip_model(
-            # This is a public model, so no auth required
-            "CompVis/stable-diffusion-v1-4",
-            None,
-            "vmfb",
-            "safetensors",
-            "stable_diffusion_v1_4_clip.safetensors",
-            "cpu",
+            hf_model_name=current_args["hf_model_name"],
+            max_length=64,
+            precision=current_args["precision"],
+            compile_to="vmfb",
+            external_weights="safetensors",
+            external_weight_path="stable_diffusion_v1_4_clip.safetensors",
+            device="cpu",
+            target_triple=None,
+            exit_on_vmfb=False,
             upload_ir=UPLOAD_IR,
         )
         current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors"
@@ -194,7 +199,7 @@ def testExportUnetModel(self):
         current_args = copy.deepcopy(default_arguments)
         blob_name = unet.export_unet_model(
             unet_model,
-            "CompVis/stable-diffusion-v1-4",
+            current_args["hf_model_name"],
             current_args["batch_size"],
             current_args["height"],
             current_args["width"],
@@ -203,12 +208,12 @@ def testExportUnetModel(self):
             None,
             "vmfb",
             "safetensors",
-            "stable_diffusion_v1_4_unet.safetensors",
+            "stable_diffusion_unet.safetensors",
             "cpu",
             upload_ir=UPLOAD_IR,
         )
-        current_args["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors"
-        current_args["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb"
"stable_diffusion_v1_4_unet.vmfb" + current_args["external_weight_path"] = "stable_diffusion_unet.safetensors" + current_args["vmfb_path"] = blob_name sample = torch.rand( current_args["batch_size"], current_args["in_channels"], @@ -219,9 +224,13 @@ def testExportUnetModel(self): timestep = torch.zeros(1, dtype=torch.float32) if current_args["hf_model_name"] == "CompVis/stable-diffusion-v1-4": - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, current_args["max_length"], 768, dtype=torch.float32 + ) elif current_args["hf_model_name"] == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, current_args["max_length"], 1024, dtype=torch.float32 + ) guidance_scale = torch.tensor( [current_args["guidance_scale"]], dtype=torch.float32 ) @@ -251,8 +260,8 @@ def testExportUnetModel(self): new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_unet.safetensors") - os.remove("stable_diffusion_v1_4_unet.vmfb") + os.remove("stable_diffusion_unet.safetensors") + os.remove(blob_name) del torch_output del turbine @@ -260,12 +269,11 @@ def testExportVaeModelDecode(self): current_args = copy.deepcopy(default_arguments) blob_name = vae.export_vae_model( vae_model, - # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", + current_args["hf_model_name"], current_args["batch_size"], current_args["height"], current_args["width"], - None, + current_args["precision"], "vmfb", "safetensors", "stable_diffusion_v1_4_vae.safetensors", @@ -303,14 +311,13 @@ def testExportVaeModelDecode(self): del torch_output del turbine os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("stable_diffusion_v1_4_vae.vmfb") + os.remove("blob_name") def testExportVaeModelEncode(self): current_args = copy.deepcopy(default_arguments) blob_name = vae.export_vae_model( vae_model, - # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", + current_args["hf_model_name"], current_args["batch_size"], current_args["height"], current_args["width"], @@ -350,7 +357,7 @@ def testExportVaeModelEncode(self): new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("stable_diffusion_v1_4_vae.vmfb") + os.remove(blob_name) @unittest.expectedFailure def testExportPNDMScheduler(self):