Skip to content

Commit

Permalink
scipy.fft, separation acceleration. fixes mir-evaluation#373
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcfee committed May 14, 2024
1 parent 485a425 commit a658f4e
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"""

import numpy as np
import scipy.fftpack
import scipy.fft
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import collections
Expand Down Expand Up @@ -692,17 +692,17 @@ def _project(reference_sources, estimated_source, flen):

# computing coefficients of least squares problem via FFT ##
# zero padding and FFT of input data
reference_sources = np.hstack((reference_sources, np.zeros((nsrc, flen - 1))))
estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.0)))
sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
sef = scipy.fftpack.fft(estimated_source, n=n_fft)
#reference_sources = np.hstack((reference_sources, np.zeros((nsrc, flen - 1))))
#estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
n_fft = scipy.fft.next_fast_len(nsampl + flen - 1, real=True)
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=1)
sef = scipy.fft.rfft(estimated_source, n=n_fft)
# inner products between delayed versions of reference_sources
G = np.zeros((nsrc * flen, nsrc * flen))
for i in range(nsrc):
for j in range(nsrc):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ssf = scipy.fft.irfft(ssf)
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen])
G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss
G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T
Expand All @@ -711,7 +711,7 @@ def _project(reference_sources, estimated_source, flen):
D = np.zeros(nsrc * flen)
for i in range(nsrc):
ssef = sf[i] * np.conj(sef)
ssef = np.real(scipy.fftpack.ifft(ssef))
ssef = scipy.fft.irfft(ssef)
D[i * flen : (i + 1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1]))

# Computing projection
Expand Down Expand Up @@ -749,9 +749,9 @@ def _project_images(reference_sources, estimated_source, flen, G=None):
estimated_source = np.hstack(
(estimated_source.transpose(), np.zeros((nchan, flen - 1)))
)
n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.0)))
sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
sef = scipy.fftpack.fft(estimated_source, n=n_fft)
n_fft = scipy.fft.next_fast_len(nsampl + flen - 1, real=True)
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=1)
sef = scipy.fft.rfft(estimated_source, n=n_fft)

# inner products between delayed versions of reference_sources
if G is None:
Expand All @@ -760,7 +760,7 @@ def _project_images(reference_sources, estimated_source, flen, G=None):
for i in range(nchan * nsrc):
for j in range(i + 1):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ssf = scipy.fft.irfft(ssf)
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen])
G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss
G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T
Expand All @@ -771,7 +771,7 @@ def _project_images(reference_sources, estimated_source, flen, G=None):
for i in range(nchan * nsrc):
for j in range(i + 1):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ssf = scipy.fft.irfft(ssf)
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen])
G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss
G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T
Expand All @@ -782,7 +782,7 @@ def _project_images(reference_sources, estimated_source, flen, G=None):
for k in range(nchan * nsrc):
for i in range(nchan):
ssef = sf[k] * np.conj(sef[i])
ssef = np.real(scipy.fftpack.ifft(ssef))
ssef = scipy.fft.irfft(ssef)
D[k * flen : (k + 1) * flen, i] = np.hstack(
(ssef[0], ssef[-1:-flen:-1])
).transpose()
Expand Down

0 comments on commit a658f4e

Please sign in to comment.