Skip to content

Commit

Permalink
refactor histogram buffer to base class
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Böck committed Sep 3, 2017
1 parent c34a6e8 commit 8777d30
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions madmom/features/tempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ class TempoHistogramProcessor(OnlineProcessor):
Minimum tempo to detect [bpm].
max_bpm : float
Maximum tempo to detect [bpm].
hist_buffer : float
Aggregate the tempo histogram over `hist_buffer` seconds.
fps : float, optional
Frames per second.
Expand All @@ -268,12 +270,17 @@ class TempoHistogramProcessor(OnlineProcessor):
"""

def __init__(self, min_bpm, max_bpm, fps=None, online=False, **kwargs):
def __init__(self, min_bpm, max_bpm, hist_buffer=HIST_BUFFER, fps=None,
online=False, **kwargs):
# pylint: disable=unused-argument
super(TempoHistogramProcessor, self).__init__(online=online)
self.min_bpm = min_bpm
self.max_bpm = max_bpm
self.hist_buffer = hist_buffer
self.fps = fps
if self.online:
self._hist_buffer = BufferProcessor((int(hist_buffer * self.fps),
len(self.intervals)))

@property
def min_interval(self):
Expand All @@ -290,6 +297,10 @@ def intervals(self):
"""Beat intervals [frames]."""
return np.arange(self.min_interval, self.max_interval + 1)

def reset(self):
"""Reset the tempo histogram aggregation buffer."""
self._hist_buffer.reset()


class CombFilterTempoHistogramProcessor(TempoHistogramProcessor):
"""
Expand All @@ -303,9 +314,8 @@ class CombFilterTempoHistogramProcessor(TempoHistogramProcessor):
Maximum tempo to detect [bpm].
alpha : float, optional
Scaling factor for the comb filter.
hist_buffer : float, optional
Use a buffer of this size to sum the max. bins in online mode
[seconds].
hist_buffer : float
Aggregate the tempo histogram over `hist_buffer` seconds.
fps : float, optional
Frames per second.
online : bool, optional
Expand All @@ -317,18 +327,17 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, alpha=ALPHA,
hist_buffer=HIST_BUFFER, fps=None, online=False, **kwargs):
# pylint: disable=unused-argument
super(CombFilterTempoHistogramProcessor, self).__init__(
min_bpm=min_bpm, max_bpm=max_bpm, fps=fps, online=online, **kwargs)
min_bpm=min_bpm, max_bpm=max_bpm, hist_buffer=hist_buffer, fps=fps,
online=online, **kwargs)
self.alpha = alpha
if self.online:
self._comb_buffer = BufferProcessor((self.max_interval + 1,
len(self.intervals)))
self._hist_buffer = BufferProcessor((int(hist_buffer * self.fps),
len(self.intervals)))

def reset(self):
"""Reset to initial state."""
super(CombFilterTempoHistogramProcessor, self).reset()
self._comb_buffer.reset()
self._hist_buffer.reset()

def process_offline(self, activations, **kwargs):
"""
Expand Down Expand Up @@ -403,9 +412,8 @@ class ACFTempoHistogramProcessor(TempoHistogramProcessor):
Minimum tempo to detect [bpm].
max_bpm : float, optional
Maximum tempo to detect [bpm].
hist_buffer : float, optional
Use a buffer of this size for the activations to calculate the
auto-correlation function [seconds].
hist_buffer : float
Aggregate the tempo histogram over `hist_buffer` seconds.
fps : float, optional
Frames per second.
online : bool, optional
Expand All @@ -417,16 +425,15 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
hist_buffer=HIST_BUFFER, fps=None, online=False, **kwargs):
# pylint: disable=unused-argument
super(ACFTempoHistogramProcessor, self).__init__(
min_bpm=min_bpm, max_bpm=max_bpm, fps=fps, online=online, **kwargs)
min_bpm=min_bpm, max_bpm=max_bpm, hist_buffer=hist_buffer, fps=fps,
online=online, **kwargs)
if self.online:
self._act_buffer = BufferProcessor((self.max_interval + 1, 1))
self._hist_buffer = BufferProcessor((int(hist_buffer * self.fps),
len(self.intervals)))

def reset(self):
"""Reset to initial state."""
super(ACFTempoHistogramProcessor, self).reset()
self._act_buffer.reset()
self._hist_buffer.reset()

def process_offline(self, activations, **kwargs):
"""
Expand Down Expand Up @@ -510,19 +517,17 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
hist_buffer=HIST_BUFFER, fps=None, online=False, **kwargs):
# pylint: disable=unused-argument
super(DBNTempoHistogramProcessor, self).__init__(
min_bpm=min_bpm, max_bpm=max_bpm, fps=fps, online=online, **kwargs)
min_bpm=min_bpm, max_bpm=max_bpm, hist_buffer=hist_buffer, fps=fps,
online=online, **kwargs)
from .beats import DBNBeatTrackingProcessor
self.dbn = DBNBeatTrackingProcessor(
min_bpm=self.min_bpm, max_bpm=self.max_bpm, fps=self.fps,
online=online, **kwargs)
if self.online:
self._hist_buffer = BufferProcessor((int(hist_buffer * self.fps),
len(self.intervals)))

def reset(self):
"""Reset DBN to initial state."""
super(DBNTempoHistogramProcessor, self).reset()
self.dbn.hmm.reset()
self._hist_buffer.reset()

def process_offline(self, activations, **kwargs):
"""
Expand Down

0 comments on commit 8777d30

Please sign in to comment.