diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 97beae035..387eabc87 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,7 +93,7 @@ def setUp(self): ) def test01_ExportClipModels(self): - if arguments["device"] in ["vulkan", "rocm", "cuda"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." ) @@ -215,7 +215,7 @@ def test01_ExportClipModels(self): np.testing.assert_allclose(torch_output_2, turbine_2[0], rtol, atol) def test02_ExportUnetModel(self): - if arguments["device"] in ["vulkan", "rocm", "cuda"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( "Unknown error on vulkan; Runtime error on rocm; To be tested on cuda." ) @@ -325,7 +325,7 @@ def test02_ExportUnetModel(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test03_ExportVaeModelDecode(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." )