Skip to content

Commit

Permalink
add DrumTranscriptor program
Browse files Browse the repository at this point in the history
  • Loading branch information
Stefan authored and Sebastian Böck committed Aug 28, 2017
1 parent 2293371 commit 4734ee9
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
86 changes: 86 additions & 0 deletions bin/DrumTranscriptor
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python
# encoding: utf-8
"""
Drum transcription with a convolutional recurrent neural network (CRNN).
"""

from __future__ import absolute_import, division, print_function

import argparse

from madmom.features import ActivationsProcessor
from madmom.features.drums import CRNNDrumProcessor, DrumPeakPickingProcessor
from madmom.features.notes import write_midi, write_notes
from madmom.processors import IOProcessor, io_arguments


def main():
"""DrumTranscriptor"""

# define parser
p = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter, description='''
Drum transcription with a convolutional recurrent neural network (CRNN).
''')
# version
p.add_argument('--version', action='version',
version='DrumTranscriptor.2017')
# input/output arguments
io_arguments(p, output_suffix='.drum_transcriptor.txt')
ActivationsProcessor.add_arguments(p)
# peak picking arguments
DrumPeakPickingProcessor.add_arguments(
p, threshold=0.15, smooth=0, pre_avg=0.1, post_avg=0.01, pre_max=0.02,
post_max=0.01, combine=0.02)
# midi arguments
p.add_argument('--midi', dest='output_format', action='store_const',
const='midi', help='save as MIDI')

# parse arguments
args = p.parse_args()

# set immutable defaults
args.fps = 100

# set the suffix for midi files
if args.output_format == 'midi':
args.output_suffix = '.mid'

# print arguments
if args.verbose:
print(args)

# input processor
if args.load:
# load the activations from file
in_processor = ActivationsProcessor(mode='r', **vars(args))
else:
# use a RNN to predict the notes
in_processor = CRNNDrumProcessor(**vars(args))

# output processor
if args.save:
# save the RNN note activations to file
out_processor = ActivationsProcessor(mode='w', **vars(args))
else:
# perform peak picking on the activation function
peak_picking = DrumPeakPickingProcessor(**vars(args))
# output everything in the right format
if args.output_format is None:
output = write_notes
elif args.output_format == 'midi':
output = write_midi
else:
raise ValueError('unknown output format: %s' % args.output_format)
out_processor = [peak_picking, output]

# create an IOProcessor
processor = IOProcessor(in_processor, out_processor)

# and call the processing function
args.func(processor, **vars(args))


if __name__ == '__main__':
main()
41 changes: 41 additions & 0 deletions tests/test_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,47 @@ def test_run(self):
self.assertTrue(np.allclose(result, self.result, atol=1e-5))


class TestDrumTranscriptorProgram(unittest.TestCase):
def setUp(self):
self.bin = pj(program_path, "DrumTranscriptor")
self.activations = Activations(
pj(ACTIVATIONS_PATH, "sample.drums_crnn.npz"))
self.result = np.loadtxt(
pj(DETECTIONS_PATH, "sample.drum_transcriptor.txt"))

def test_help(self):
self.assertTrue(run_help(self.bin))

def test_binary(self):
# save activations as binary file
run_program([self.bin, '--save', 'single', sample_file,
'-o', tmp_act])
act = Activations(tmp_act)
self.assertTrue(np.allclose(act, self.activations, atol=1e-5))
self.assertEqual(act.fps, self.activations.fps)
# reload from file
run_program([self.bin, '--load', 'single', tmp_act, '-o', tmp_result])
result = np.loadtxt(tmp_result)
self.assertTrue(np.allclose(result, self.result, atol=1e-5))

def test_txt(self):
# save activations as txt file
run_program([self.bin, '--save', '--sep', ' ', 'single',
sample_file, '-o', tmp_act])
act = Activations(tmp_act, sep=' ', fps=100)
self.assertTrue(np.allclose(act, self.activations, atol=1e-5))
# reload from file
run_program([self.bin, '--load', '--sep', ' ', 'single', tmp_act,
'-o', tmp_result])
result = np.loadtxt(tmp_result)
self.assertTrue(np.allclose(result, self.result, atol=1e-5))

def test_run(self):
run_program([self.bin, 'single', sample_file, '-o', tmp_result])
result = np.loadtxt(tmp_result)
self.assertTrue(np.allclose(result, self.result, atol=1e-5))


class TestSpectralOnsetDetectionProgram(unittest.TestCase):
def setUp(self):
self.bin = pj(program_path, "SpectralOnsetDetection")
Expand Down

0 comments on commit 4734ee9

Please sign in to comment.