Skip to content

Commit

Permalink
Allow mps device but set PYTORCH_ENABLE_MPS_FALLBACK
Browse files Browse the repository at this point in the history
This is needed since the operator 'aten::unfold_backward' is not currently implemented for MPS in pytorch - see pytorch/pytorch#77764
  • Loading branch information
BenRogersNewsome committed Jan 24, 2025
1 parent d9683b6 commit 8f3c6c7
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions noisereduce/spectralgate/streamed_torch_gate.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
import os
import torch
from noisereduce.spectralgate.base import SpectralGate
from noisereduce.torchgate import TorchGate as TG
import numpy as np


class StreamedTorchGate(SpectralGate):
'''
"""
Run interface with noisereduce.
'''
"""

def __init__(
self,
y,
sr,
stationary=False,
y_noise=None,
prop_decrease=1.0,
time_constant_s=2.0,
freq_mask_smooth_hz=500,
time_mask_smooth_ms=50,
thresh_n_mult_nonstationary=2,
sigmoid_slope_nonstationary=10,
n_std_thresh_stationary=1.5,
tmp_folder=None,
chunk_size=600000,
padding=30000,
n_fft=1024,
win_length=None,
hop_length=None,
clip_noise_stationary=True,
use_tqdm=False,
n_jobs=1,
device="cuda",
self,
y,
sr,
stationary=False,
y_noise=None,
prop_decrease=1.0,
time_constant_s=2.0,
freq_mask_smooth_hz=500,
time_mask_smooth_ms=50,
thresh_n_mult_nonstationary=2,
sigmoid_slope_nonstationary=10,
n_std_thresh_stationary=1.5,
tmp_folder=None,
chunk_size=600000,
padding=30000,
n_fft=1024,
win_length=None,
hop_length=None,
clip_noise_stationary=True,
use_tqdm=False,
n_jobs=1,
device="cuda",
):
super().__init__(
y=y,
Expand All @@ -50,7 +51,14 @@ def __init__(
n_jobs=n_jobs,
)

self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
elif device == "mps":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
device = "cpu"

self.device = torch.device(device)

# noise convert to torch if needed
if y_noise is not None:
Expand Down

0 comments on commit 8f3c6c7

Please sign in to comment.