diff --git a/noisereduce/spectralgate/streamed_torch_gate.py b/noisereduce/spectralgate/streamed_torch_gate.py index 3a990aa..6176413 100644 --- a/noisereduce/spectralgate/streamed_torch_gate.py +++ b/noisereduce/spectralgate/streamed_torch_gate.py @@ -1,3 +1,4 @@ +import os import torch from noisereduce.spectralgate.base import SpectralGate from noisereduce.torchgate import TorchGate as TG @@ -50,7 +51,12 @@ def __init__( n_jobs=n_jobs, ) - self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + if "cuda" in device and not torch.cuda.is_available(): + device = "cpu" + elif device == "mps": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + + self.device = torch.device(device) # noise convert to torch if needed if y_noise is not None: