From a8221c51dac6a232c0c88b55e81a4b87e57951c0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 15 Dec 2023 18:34:46 +0530 Subject: [PATCH] Compile test fix (#6104) * update * update --- src/diffusers/utils/testing_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 14b89c3cd3b9b..606980f8a3c5c 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -820,7 +820,9 @@ def _is_torch_fp16_available(device): try: x = torch.zeros((2, 2), dtype=torch.float16).to(device) - _ = x @ x + _ = torch.mul(x, x) + return True + except Exception as e: if device.type == "cuda": raise ValueError( @@ -838,7 +840,9 @@ def _is_torch_fp64_available(device): try: x = torch.zeros((2, 2), dtype=torch.float64).to(device) - _ = x @ x + _ = torch.mul(x, x) + return True + except Exception as e: if device.type == "cuda": raise ValueError(