From 7cc1be5e6fb17f0aae4c9187390c714158015184 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 25 Jun 2024 20:39:13 -0500 Subject: [PATCH] (SD) Cleanup after each test. (#744) Segfaults otherwise on certain runners. --- models/turbine_models/tests/sd_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 76c11bcba..9d379f571 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -112,6 +112,7 @@ def testExportT5Model(self): new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) del current_args + del turbine def testExportClipVitLarge14(self): current_args = copy.deepcopy(default_arguments) @@ -152,6 +153,8 @@ def testExportClipVitLarge14(self): if platform.system() != "Windows": os.remove(current_args["external_weight_path"]) os.remove(current_args["vmfb_path"]) + del current_args + del turbine def testExportClipModel(self): current_args = copy.deepcopy(default_arguments) @@ -190,7 +193,10 @@ def testExportClipModel(self): if platform.system() != "Windows": os.remove(current_args["external_weight_path"]) os.remove(current_args["vmfb_path"]) + del current_args + del turbine + @unittest.expectedFailure def testExportUnetModel(self): current_args = copy.deepcopy(default_arguments) blob_name = unet.export_unet_model( @@ -301,6 +307,7 @@ def testExportVaeModelDecode(self): new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) + del current_args del torch_output del turbine os.remove("stable_diffusion_v1_4_vae.safetensors") @@ -352,6 +359,8 @@ def testExportVaeModelEncode(self): turbine_tank.changeBlobName(blob_name, new_blob_name) os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") + del current_args + del turbine @unittest.expectedFailure def testExportPNDMScheduler(self): @@ -405,6 +414,7 @@ def testExportPNDMScheduler(self): turbine_tank.changeBlobName(blob_name, new_blob_name) os.remove("stable_diffusion_v1_4_scheduler.safetensors") os.remove("stable_diffusion_v1_4_scheduler.vmfb") + del current_args del torch_output del turbine