diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 14b89c3cd3b9b..606980f8a3c5c 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -820,7 +820,9 @@ def _is_torch_fp16_available(device): try: x = torch.zeros((2, 2), dtype=torch.float16).to(device) - _ = x @ x + _ = torch.mul(x, x) + return True + except Exception as e: if device.type == "cuda": raise ValueError( @@ -838,7 +840,9 @@ def _is_torch_fp64_available(device): try: x = torch.zeros((2, 2), dtype=torch.float64).to(device) - _ = x @ x + _ = torch.mul(x, x) + return True + except Exception as e: if device.type == "cuda": raise ValueError(