Skip to content

Commit

Permalink
Allow mps device but set PYTORCH_ENABLE_MPS_FALLBACK
Browse files Browse the repository at this point in the history
This is needed since the operator 'aten::unfold_backward' is not currently implemented for MPS in pytorch - see pytorch/pytorch#77764
  • Loading branch information
BenRogersNewsome committed Jan 24, 2025
1 parent 6375a6a commit 73e7028
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion noisereduce/spectralgate/streamed_torch_gate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
from noisereduce.spectralgate.base import SpectralGate
from noisereduce.torchgate import TorchGate as TG
Expand Down Expand Up @@ -50,7 +51,12 @@ def __init__(
n_jobs=n_jobs,
)

self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
if "cuda" in device and not torch.cuda.is_available():
device = "cpu"
elif device == "mps":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

self.device = torch.device(device)

# noise convert to torch if needed
if y_noise is not None:
Expand Down

0 comments on commit 73e7028

Please sign in to comment.