From 4734ee912ee161e3bd1d8f6aa7c64bdce17504d1 Mon Sep 17 00:00:00 2001
From: Stefan
Date: Thu, 24 Aug 2017 00:32:03 +0200
Subject: [PATCH] add DrumTranscriptor program

---
 bin/DrumTranscriptor | 86 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_bin.py    | 41 +++++++++++++++++++++
 2 files changed, 127 insertions(+)
 create mode 100644 bin/DrumTranscriptor

diff --git a/bin/DrumTranscriptor b/bin/DrumTranscriptor
new file mode 100644
index 000000000..3bb3918ae
--- /dev/null
+++ b/bin/DrumTranscriptor
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+Drum transcription with a convolutional recurrent neural network (CRNN).
+
+"""
+
+from __future__ import absolute_import, division, print_function
+
+import argparse
+
+from madmom.features import ActivationsProcessor
+from madmom.features.drums import CRNNDrumProcessor, DrumPeakPickingProcessor
+from madmom.features.notes import write_midi, write_notes
+from madmom.processors import IOProcessor, io_arguments
+
+
+def main():
+    """DrumTranscriptor"""
+
+    # define parser
+    p = argparse.ArgumentParser(
+        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
+    Drum transcription with a convolutional recurrent neural network (CRNN).
+    ''')
+    # version
+    p.add_argument('--version', action='version',
+                   version='DrumTranscriptor.2017')
+    # input/output arguments
+    io_arguments(p, output_suffix='.drum_transcriptor.txt')
+    ActivationsProcessor.add_arguments(p)
+    # peak picking arguments
+    DrumPeakPickingProcessor.add_arguments(
+        p, threshold=0.15, smooth=0, pre_avg=0.1, post_avg=0.01, pre_max=0.02,
+        post_max=0.01, combine=0.02)
+    # midi arguments
+    p.add_argument('--midi', dest='output_format', action='store_const',
+                   const='midi', help='save as MIDI')
+
+    # parse arguments
+    args = p.parse_args()
+
+    # set immutable defaults
+    args.fps = 100
+
+    # set the suffix for midi files
+    if args.output_format == 'midi':
+        args.output_suffix = '.mid'
+
+    # print arguments
+    if args.verbose:
+        print(args)
+
+    # input processor
+    if args.load:
+        # load the activations from file
+        in_processor = ActivationsProcessor(mode='r', **vars(args))
+    else:
+        # use the CRNN to compute the drum activations
+        in_processor = CRNNDrumProcessor(**vars(args))
+
+    # output processor
+    if args.save:
+        # save the CRNN drum activations to file
+        out_processor = ActivationsProcessor(mode='w', **vars(args))
+    else:
+        # perform peak picking on the activation function
+        peak_picking = DrumPeakPickingProcessor(**vars(args))
+        # output everything in the right format
+        if args.output_format is None:
+            output = write_notes
+        elif args.output_format == 'midi':
+            output = write_midi
+        else:
+            raise ValueError('unknown output format: %s' % args.output_format)
+        out_processor = [peak_picking, output]
+
+    # create an IOProcessor
+    processor = IOProcessor(in_processor, out_processor)
+
+    # and call the processing function
+    args.func(processor, **vars(args))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_bin.py b/tests/test_bin.py
index f0d01f4ae..8bfb1e7f3 100644
--- a/tests/test_bin.py
+++ b/tests/test_bin.py
@@ -742,6 +742,47 @@ def test_run(self):
         self.assertTrue(np.allclose(result, self.result, atol=1e-5))
 
 
+class TestDrumTranscriptorProgram(unittest.TestCase):
+    def setUp(self):
+        self.bin = pj(program_path, "DrumTranscriptor")
+        self.activations = Activations(
+            pj(ACTIVATIONS_PATH, "sample.drums_crnn.npz"))
+        self.result = np.loadtxt(
+            pj(DETECTIONS_PATH, "sample.drum_transcriptor.txt"))
+
+    def test_help(self):
+        self.assertTrue(run_help(self.bin))
+
+    def test_binary(self):
+        # save activations as binary file
+        run_program([self.bin, '--save', 'single', sample_file,
+                     '-o', tmp_act])
+        act = Activations(tmp_act)
+        self.assertTrue(np.allclose(act, self.activations, atol=1e-5))
+        self.assertEqual(act.fps, self.activations.fps)
+        # reload from file
+        run_program([self.bin, '--load', 'single', tmp_act, '-o', tmp_result])
+        result = np.loadtxt(tmp_result)
+        self.assertTrue(np.allclose(result, self.result, atol=1e-5))
+
+    def test_txt(self):
+        # save activations as txt file
+        run_program([self.bin, '--save', '--sep', ' ', 'single',
+                     sample_file, '-o', tmp_act])
+        act = Activations(tmp_act, sep=' ', fps=100)
+        self.assertTrue(np.allclose(act, self.activations, atol=1e-5))
+        # reload from file
+        run_program([self.bin, '--load', '--sep', ' ', 'single', tmp_act,
+                     '-o', tmp_result])
+        result = np.loadtxt(tmp_result)
+        self.assertTrue(np.allclose(result, self.result, atol=1e-5))
+
+    def test_run(self):
+        run_program([self.bin, 'single', sample_file, '-o', tmp_result])
+        result = np.loadtxt(tmp_result)
+        self.assertTrue(np.allclose(result, self.result, atol=1e-5))
+
+
 class TestSpectralOnsetDetectionProgram(unittest.TestCase):
     def setUp(self):
         self.bin = pj(program_path, "SpectralOnsetDetection")