From 7997fdf3972f992209eaf144f37c18926fa3c960 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Mon, 25 Mar 2024 14:53:52 -0400 Subject: [PATCH] Test and CI modernization (#370) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test_alignments → pytest * test_key → pytest * test_onset → pytest * test_patterns → pytest * updating sonify tests * transcription velocity updates * fixed a hack in onset test * test_beats → pytest * test_beats → pytest * test_beats → pytest * updating CI configs to run again * adding scipy config dump to ci action * reverting some working direcory business * modernized test_segment * migrated to setup.cfg * adding minimal dependency environment for CI * fixing ci spec * ignore coverage files * modernized multipitch tests * temporarily disabling concurrency failures to test-drive new environment configs * py312 issues * modernized test_tempo * removed unused import * test_transcription modernized * removed nose from display test * fixed display test code * modernized melody tests * modernized util tests * modernized test_io * modernized hierarchy tests * modernized separation tests * chord test translation in progress * chord test translation in progress * chord test translation in progress * test modernization complete * blacked the tests * Adding linter action * ingesting style check configs * mild spellchecking * Fixed spellcheck dictionary * blaked package code * velinized docstrings * velinized docstrings * docstyle checks on util and key * docstyle updatse * docstyle updates for chord * docstyle updates * velinizing again * docstyle on patterns * docstyle pass * black updates * don't require bz2 for conda packages * let's see if tv tests pass on minimal env * let's see if tv tests pass on minimal env * disabling group concurrency for now * trying once more with disbaling fail-fast * skipping broken chord sonification test for now * skipping display tests for now * blacking tests again * explicitizing a scalar conversion in goto metric to avoid numpy warning * dancing around fancy index deprecations in numpy * dancing around fancy index deprecations in numpy * suppressing noisy warnings in melody * making segment empty tests more precise * catching more warnings in separation test * blacking tests * bumpy min numpy version, redoing transscription velocity fixture 2 * formatting * updated minimal scipy to 1.4 * bumping min matplotlib as well * bumping min matplotlib as well * bumping matplotlib to 3.0 (2018) * making minimal numpy match scipy requirement * Trying mpl-base instead of full package for test environment * Maybe mpl 3.3 will work * bumping numpy to 1.15.4 * removing janky warn nesting in separation test * forgot to modernize multipitch tests * update action versions * fixed content type for setup.cfg description --- .codespell_ignore_list | 5 + .github/environment-lint.yml | 21 + .github/environment-minimal.yml | 12 + .github/environment.yml | 6 +- .github/workflows/lint_python.yml | 71 ++ .github/workflows/test.yml | 55 +- .gitignore | 6 +- mir_eval/__init__.py | 2 +- mir_eval/alignment.py | 19 +- mir_eval/beat.py | 320 +++--- mir_eval/chord.py | 362 +++---- mir_eval/display.py | 306 +++--- mir_eval/hierarchy.py | 272 +++-- mir_eval/io.py | 198 ++-- mir_eval/key.py | 97 +- mir_eval/melody.py | 274 ++--- mir_eval/multipitch.py | 159 +-- mir_eval/onset.py | 29 +- mir_eval/pattern.py | 250 ++--- mir_eval/segment.py | 471 +++++---- mir_eval/separation.py | 427 ++++---- mir_eval/sonify.py | 132 ++- mir_eval/tempo.py | 60 +- mir_eval/transcription.py | 269 +++-- mir_eval/transcription_velocity.py | 191 +++- mir_eval/util.py | 221 ++-- setup.cfg | 60 ++ setup.py | 39 +- .../data/transcription_velocity/output2.json | 2 +- tests/generate_data.py | 97 +- tests/test_alignment.py | 108 +- tests/test_beat.py | 194 ++-- tests/test_chord.py | 981 +++++++++--------- tests/test_display.py | 239 +++-- tests/test_hierarchy.py | 302 +++--- tests/test_input_output.py | 181 ++-- tests/test_key.py | 110 +- tests/test_melody.py | 472 ++++----- tests/test_multipitch.py | 238 +++-- tests/test_onset.py | 130 +-- tests/test_pattern.py | 141 ++- tests/test_segment.py | 327 +++--- tests/test_separation.py | 544 ++++++---- tests/test_sonify.py | 124 ++- tests/test_tempo.py | 151 +-- tests/test_transcription.py | 333 +++--- tests/test_transcription_velocity.py | 156 +-- tests/test_util.py | 356 ++++--- 48 files changed, 5108 insertions(+), 4412 deletions(-) create mode 100644 .codespell_ignore_list create mode 100644 .github/environment-lint.yml create mode 100644 .github/environment-minimal.yml create mode 100644 .github/workflows/lint_python.yml create mode 100644 setup.cfg diff --git a/.codespell_ignore_list b/.codespell_ignore_list new file mode 100644 index 00000000..87defebc --- /dev/null +++ b/.codespell_ignore_list @@ -0,0 +1,5 @@ +nce +fpr +shepard +dum +theis diff --git a/.github/environment-lint.yml b/.github/environment-lint.yml new file mode 100644 index 00000000..ba2e6838 --- /dev/null +++ b/.github/environment-lint.yml @@ -0,0 +1,21 @@ +name: lint +channels: + - conda-forge + - defaults +dependencies: + # required + - pip + - bandit + - codespell + - flake8 + - pytest + - pydocstyle + + # Dependencies for velin + - numpydoc>=1.1.0 + - sphinx>=5.1.0 + - pygments + - black + + - pip: + - velin diff --git a/.github/environment-minimal.yml b/.github/environment-minimal.yml new file mode 100644 index 00000000..6e5560da --- /dev/null +++ b/.github/environment-minimal.yml @@ -0,0 +1,12 @@ +name: test +channels: + - conda-forge + - defaults +dependencies: + - pip + - numpy ==1.15.4 + - scipy ==1.4.0 + - matplotlib-base==3.3.0 + - pytest + - pytest-cov + - pytest-mpl diff --git a/.github/environment.yml b/.github/environment.yml index a499d7b5..68641776 100644 --- a/.github/environment.yml +++ b/.github/environment.yml @@ -4,9 +4,9 @@ channels: - defaults dependencies: - pip - - numpy - - scipy - - matplotlib + - numpy >=1.15.4 + - scipy >=1.4.0 + - matplotlib-base>=3.3.0 - pytest - pytest-cov - pytest-mpl diff --git a/.github/workflows/lint_python.yml b/.github/workflows/lint_python.yml new file mode 100644 index 00000000..48005966 --- /dev/null +++ b/.github/workflows/lint_python.yml @@ -0,0 +1,71 @@ +name: lint_python +on: [pull_request, push] +jobs: + lint_python: + name: "Lint and code analysis" + runs-on: ubuntu-latest + strategy: + fail-fast: true + matrix: + include: + - os: ubuntu-latest + python-version: "3.11" + channel-priority: "flexible" + envfile: ".github/environment-lint.yml" + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Cache conda + uses: actions/cache@v4 + env: + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-${{ matrix.python-version }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles( matrix.envfile ) }} + - name: Install conda environmnent + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: false + python-version: ${{ matrix.python-version }} + add-pip-as-python-dependency: true + auto-activate-base: false + activate-environment: lint + # mamba-version: "*" + channel-priority: ${{ matrix.channel-priority }} + environment-file: ${{ matrix.envfile }} + use-only-tar-bz2: false + + - name: Conda info + shell: bash -l {0} + run: | + conda info -a + conda list + + - name: Spell check package + shell: bash -l {0} + run: codespell --ignore-words .codespell_ignore_list mir_eval + + - name: Security check + shell: bash -l {0} + run: bandit --recursive --skip B101,B110 . + + - name: Style check package + shell: bash -l {0} + run: python -m flake8 mir_eval + + - name: Format check package + shell: bash -l {0} + run: python -m black --check mir_eval + + - name: Format check tests + shell: bash -l {0} + run: python -m black --check tests + + - name: Docstring check + shell: bash -l {0} + run: python -m velin --check mir_eval + + - name: Docstring style check + shell: bash -l {0} + run: python -m pydocstyle mir_eval diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 305de62e..80c64e3b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,48 +3,58 @@ name: Test Python code on: pull_request: branches: - - master + - main push: branches: - - master + - main concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: True + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: True jobs: test: strategy: + fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.10", "3.11"] channel-priority: [strict] + envfile: [".github/environment.yml"] include: - - python-version: "3.10" + - python-version: "3.12" os: macos-latest - - python-version: "3.10" + - python-version: "3.12" os: windows-latest - - python-version: "3.10" + - python-version: "3.12" os: ubuntu-latest channel-priority: flexible + - os: ubuntu-latest + python-version: "3.7" + envfile: ".github/environment-minimal.yml" + channel-priority: "flexible" + name: "Minimal dependencies" + runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/conda_pkgs_dir - key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('.github/environment.yml') }} + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles( matrix.envfile ) }} - name: Create conda environment - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python-version }} auto-activate-base: false channel-priority: ${{ matrix.channel-priority }} - environment-file: .github/environment.yml - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! + environment-file: ${{ matrix.envfile }} + # Disabling bz2 to get more recent dependencies. + # NOTE: this breaks cache support, so CI will be slower. + use-only-tar-bz2: False # IMPORTANT: This needs to be set for caching to work properly! - name: Install package in development mode shell: bash -l {0} @@ -60,7 +70,22 @@ jobs: shell: bash -l {0} run: python -c "import numpy; numpy.show_config()" + - name: Show libraries in the system on which SciPy was built + shell: bash -l {0} + run: python -c "import scipy; scipy.show_config()" + - name: Run unit tests shell: bash -l {0} - run: pytest --cov=mir_eval + run: pytest working-directory: tests + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: true + verbose: true diff --git a/.gitignore b/.gitignore index 53e602e3..6eece87b 100644 --- a/.gitignore +++ b/.gitignore @@ -44,5 +44,9 @@ Thumbs.db # docs docs/_build/* -# matplotlib tsets +# matplotlib tests tests/result_images/* + +# coverage +coverage.xml + diff --git a/mir_eval/__init__.py b/mir_eval/__init__.py index bf50c2d7..1eafe549 100644 --- a/mir_eval/__init__.py +++ b/mir_eval/__init__.py @@ -20,4 +20,4 @@ from . import transcription_velocity from . import key -__version__ = '0.7' +__version__ = "0.7" diff --git a/mir_eval/alignment.py b/mir_eval/alignment.py index 7f32c527..e8bf22fe 100644 --- a/mir_eval/alignment.py +++ b/mir_eval/alignment.py @@ -56,10 +56,8 @@ from mir_eval.util import filter_kwargs -def validate( - reference_timestamps: np.ndarray, estimated_timestamps: np.ndarray -): - """Checks that the input annotations to a metric look like valid onset time +def validate(reference_timestamps: np.ndarray, estimated_timestamps: np.ndarray): + """Check that the input annotations to a metric look like valid onset time arrays, and throws helpful errors if not. Parameters @@ -103,13 +101,9 @@ def validate( # Check monotonicity if not np.all(reference_timestamps[1:] - reference_timestamps[:-1] >= 0): - raise ValueError( - "Reference timestamps are not monotonically increasing!" - ) + raise ValueError("Reference timestamps are not monotonically increasing!") if not np.all(estimated_timestamps[1:] - estimated_timestamps[:-1] >= 0): - raise ValueError( - "Estimated timestamps are not monotonically increasing!" - ) + raise ValueError("Estimated timestamps are not monotonically increasing!") # Check positivity (need for correct PCS metric calculation) if not np.all(reference_timestamps >= 0): @@ -181,7 +175,7 @@ def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3): def percentage_correct_segments( reference_timestamps, estimated_timestamps, duration: Optional[float] = None ): - """Calculates the percentage of correct segments (PCS) metric. + """Calculate the percentage of correct segments (PCS) metric. It constructs segments out of predicted and estimated timestamps separately out of each given timestamp vector and calculates the percentage of overlap between correct @@ -317,6 +311,7 @@ def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps): def evaluate(reference_timestamps, estimated_timestamps, **kwargs): """Compute all metrics for the given reference and estimated annotations. + Examples -------- >>> reference_timestamps = mir_eval.io.load_events('reference.txt') @@ -330,7 +325,7 @@ def evaluate(reference_timestamps, estimated_timestamps, **kwargs): reference timestamp locations, in seconds estimated_timestamps : np.ndarray estimated timestamp locations, in seconds - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. diff --git a/mir_eval/beat.py b/mir_eval/beat.py index 6a34c317..21211a07 100644 --- a/mir_eval/beat.py +++ b/mir_eval/beat.py @@ -1,4 +1,4 @@ -''' +r""" The aim of a beat detection algorithm is to report the times at which a typical human listener might tap their foot to a piece of music. As a result, most metrics for evaluating the performance of beat tracking systems involve @@ -42,7 +42,7 @@ * :func:`mir_eval.beat.information_gain`: The Information Gain of a normalized beat error histogram over a uniform distribution -''' +""" import numpy as np import collections @@ -51,11 +51,11 @@ # The maximum allowable beat time -MAX_TIME = 30000. +MAX_TIME = 30000.0 -def trim_beats(beats, min_beat_time=5.): - """Removes beats before min_beat_time. A common preprocessing step. +def trim_beats(beats, min_beat_time=5.0): + """Remove beats before min_beat_time. A common preprocessing step. Parameters ---------- @@ -75,7 +75,7 @@ def trim_beats(beats, min_beat_time=5.): def validate(reference_beats, estimated_beats): - """Checks that the input annotations to a metric look like valid beat time + """Check that the input annotations to a metric look like valid beat time arrays, and throws helpful errors if not. Parameters @@ -115,27 +115,25 @@ def _get_reference_beat_variations(reference_beats): Half tempo, odd beats half_even : np.ndarray Half tempo, even beats - """ - # Create annotations at twice the metric level - interpolated_indices = np.arange(0, reference_beats.shape[0]-.5, .5) + interpolated_indices = np.arange(0, reference_beats.shape[0] - 0.5, 0.5) original_indices = np.arange(0, reference_beats.shape[0]) - double_reference_beats = np.interp(interpolated_indices, - original_indices, - reference_beats) + double_reference_beats = np.interp( + interpolated_indices, original_indices, reference_beats + ) # Return metric variations: # True, off-beat, double tempo, half tempo odd, and half tempo even - return (reference_beats, - double_reference_beats[1::2], - double_reference_beats, - reference_beats[::2], - reference_beats[1::2]) + return ( + reference_beats, + double_reference_beats[1::2], + double_reference_beats, + reference_beats[::2], + reference_beats[1::2], + ) -def f_measure(reference_beats, - estimated_beats, - f_measure_threshold=0.07): +def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): """Compute the F-measure of correct vs incorrectly predicted beats. "Correctness" is determined over a small window. @@ -167,20 +165,16 @@ def f_measure(reference_beats, validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: - return 0. + return 0.0 # Compute the best-case matching between reference and estimated locations - matching = util.match_events(reference_beats, - estimated_beats, - f_measure_threshold) + matching = util.match_events(reference_beats, estimated_beats, f_measure_threshold) - precision = float(len(matching))/len(estimated_beats) - recall = float(len(matching))/len(reference_beats) + precision = float(len(matching)) / len(estimated_beats) + recall = float(len(matching)) / len(reference_beats) return util.f_measure(precision, recall) -def cemgil(reference_beats, - estimated_beats, - cemgil_sigma=0.04): +def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): """Cemgil's score, computes a gaussian error of each estimated beat. Compares against the original beat times and all metrical variations. @@ -213,7 +207,7 @@ def cemgil(reference_beats, validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: - return 0., 0. + return 0.0, 0.0 # We'll compute Cemgil's accuracy for each variation accuracies = [] for reference_beats in _get_reference_beat_variations(reference_beats): @@ -223,9 +217,9 @@ def cemgil(reference_beats, # Find the error for the closest beat to the reference beat beat_diff = np.min(np.abs(beat - estimated_beats)) # Add gaussian error into the accuracy - accuracy += np.exp(-(beat_diff**2)/(2.0*cemgil_sigma**2)) + accuracy += np.exp(-(beat_diff**2) / (2.0 * cemgil_sigma**2)) # Normalize the accuracy - accuracy /= .5*(estimated_beats.shape[0] + reference_beats.shape[0]) + accuracy /= 0.5 * (estimated_beats.shape[0] + reference_beats.shape[0]) # Add it to our list of accuracy scores accuracies.append(accuracy) # Return raw accuracy with non-varied annotations @@ -233,11 +227,9 @@ def cemgil(reference_beats, return accuracies[0], np.max(accuracies) -def goto(reference_beats, - estimated_beats, - goto_threshold=0.35, - goto_mu=0.2, - goto_sigma=0.2): +def goto( + reference_beats, estimated_beats, goto_threshold=0.35, goto_mu=0.2, goto_sigma=0.2 +): """Calculate Goto's score, a binary 1 or 0 depending on some specific heuristic criteria @@ -275,25 +267,26 @@ def goto(reference_beats, validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: - return 0. + return 0.0 # Error for each beat beat_error = np.ones(reference_beats.shape[0]) # Flag for whether the reference and estimated beats are paired paired = np.zeros(reference_beats.shape[0]) # Keep track of Goto's three criteria goto_criteria = 0 - for n in range(1, reference_beats.shape[0]-1): + for n in range(1, reference_beats.shape[0] - 1): # Get previous inner-reference-beat-interval - previous_interval = 0.5*(reference_beats[n] - reference_beats[n-1]) + previous_interval = 0.5 * (reference_beats[n] - reference_beats[n - 1]) # Window start - in the middle of the current beat and the previous window_min = reference_beats[n] - previous_interval # Next inter-reference-beat-interval - next_interval = 0.5*(reference_beats[n+1] - reference_beats[n]) + next_interval = 0.5 * (reference_beats[n + 1] - reference_beats[n]) # Window end - in the middle of the current beat and the next window_max = reference_beats[n] + next_interval # Get estimated beats in the window - beats_in_window = np.logical_and((estimated_beats >= window_min), - (estimated_beats < window_max)) + beats_in_window = np.logical_and( + (estimated_beats >= window_min), (estimated_beats < window_max) + ) # False negative/positive if beats_in_window.sum() == 0 or beats_in_window.sum() > 1: paired[n] = 0 @@ -305,39 +298,36 @@ def goto(reference_beats, offset = estimated_beats[beats_in_window] - reference_beats[n] # Scale by previous or next interval if offset < 0: - beat_error[n] = offset/previous_interval + beat_error[n] = offset[0] / previous_interval else: - beat_error[n] = offset/next_interval + beat_error[n] = offset[0] / next_interval # Get indices of incorrect beats incorrect_beats = np.flatnonzero(np.abs(beat_error) > goto_threshold) # All beats are correct (first and last will be 0 so always correct) if incorrect_beats.shape[0] < 3: # Get the track of correct beats - track = beat_error[incorrect_beats[0] + 1:incorrect_beats[-1] - 1] + track = beat_error[incorrect_beats[0] + 1 : incorrect_beats[-1] - 1] goto_criteria = 1 else: # Get the track of maximal length track_len = np.max(np.diff(incorrect_beats)) track_start = np.flatnonzero(np.diff(incorrect_beats) == track_len)[0] # Is the track length at least 25% of the song? - if track_len - 1 > .25*(reference_beats.shape[0] - 2): + if track_len - 1 > 0.25 * (reference_beats.shape[0] - 2): goto_criteria = 1 start_beat = incorrect_beats[track_start] end_beat = incorrect_beats[track_start + 1] - track = beat_error[start_beat:end_beat + 1] + track = beat_error[start_beat : end_beat + 1] # If we have a track if goto_criteria: # Are mean and std of the track less than the required thresholds? - if np.mean(np.abs(track)) < goto_mu \ - and np.std(track, ddof=1) < goto_sigma: + if np.mean(np.abs(track)) < goto_mu and np.std(track, ddof=1) < goto_sigma: goto_criteria = 3 # If all criteria are met, score is 100%! - return 1.0*(goto_criteria == 3) + return 1.0 * (goto_criteria == 3) -def p_score(reference_beats, - estimated_beats, - p_score_threshold=0.2): +def p_score(reference_beats, estimated_beats, p_score_threshold=0.2): """Get McKinney's P-score. Based on the autocorrelation of the reference and estimated beats @@ -370,52 +360,59 @@ def p_score(reference_beats, # Warn when only one beat is provided for either estimated or reference, # report a warning if reference_beats.size == 1: - warnings.warn("Only one reference beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one reference beat was provided, so beat intervals" + " cannot be computed." + ) if estimated_beats.size == 1: - warnings.warn("Only one estimated beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one estimated beat was provided, so beat intervals" + " cannot be computed." + ) # When estimated or reference beats have <= 1 beats, can't compute the # metric, so return 0 if estimated_beats.size <= 1 or reference_beats.size <= 1: - return 0. + return 0.0 # Quantize beats to 10ms - sampling_rate = int(1.0/0.010) + sampling_rate = int(1.0 / 0.010) # Shift beats so that the minimum in either sequence is zero offset = min(estimated_beats.min(), reference_beats.min()) estimated_beats = np.array(estimated_beats - offset) reference_beats = np.array(reference_beats - offset) # Get the largest time index - end_point = np.int64(np.ceil(np.max([np.max(estimated_beats), - np.max(reference_beats)]))) + end_point = np.int64( + np.ceil(np.max([np.max(estimated_beats), np.max(reference_beats)])) + ) # Make impulse trains with impulses at beat locations - reference_train = np.zeros(end_point*sampling_rate + 1) - beat_indices = np.ceil(reference_beats*sampling_rate).astype(np.int64) + reference_train = np.zeros(end_point * sampling_rate + 1) + beat_indices = np.ceil(reference_beats * sampling_rate).astype(np.int64) reference_train[beat_indices] = 1.0 - estimated_train = np.zeros(end_point*sampling_rate + 1) - beat_indices = np.ceil(estimated_beats*sampling_rate).astype(np.int64) + estimated_train = np.zeros(end_point * sampling_rate + 1) + beat_indices = np.ceil(estimated_beats * sampling_rate).astype(np.int64) estimated_train[beat_indices] = 1.0 # Window size to take the correlation over # defined as .2*median(inter-annotation-intervals) annotation_intervals = np.diff(np.flatnonzero(reference_train)) - win_size = int(np.round(p_score_threshold*np.median(annotation_intervals))) + win_size = int(np.round(p_score_threshold * np.median(annotation_intervals))) # Get full correlation - train_correlation = np.correlate(reference_train, estimated_train, 'full') + train_correlation = np.correlate(reference_train, estimated_train, "full") # Get the middle element - note we are rounding down on purpose here - middle_lag = train_correlation.shape[0]//2 + middle_lag = train_correlation.shape[0] // 2 # Truncate to only valid lags (those corresponding to the window) start = middle_lag - win_size end = middle_lag + win_size + 1 train_correlation = train_correlation[start:end] # Compute and return the P-score n_beats = np.max([estimated_beats.shape[0], reference_beats.shape[0]]) - return np.sum(train_correlation)/n_beats + return np.sum(train_correlation) / n_beats -def continuity(reference_beats, - estimated_beats, - continuity_phase_threshold=0.175, - continuity_period_threshold=0.175): +def continuity( + reference_beats, + estimated_beats, + continuity_phase_threshold=0.175, + continuity_period_threshold=0.175, +): """Get metrics based on how much of the estimated beat sequence is continually correct. @@ -458,23 +455,26 @@ def continuity(reference_beats, # Warn when only one beat is provided for either estimated or reference, # report a warning if reference_beats.size == 1: - warnings.warn("Only one reference beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one reference beat was provided, so beat intervals" + " cannot be computed." + ) if estimated_beats.size == 1: - warnings.warn("Only one estimated beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one estimated beat was provided, so beat intervals" + " cannot be computed." + ) # When estimated or reference beats have <= 1 beats, can't compute the # metric, so return 0 if estimated_beats.size <= 1 or reference_beats.size <= 1: - return 0., 0., 0., 0. + return 0.0, 0.0, 0.0, 0.0 # Accuracies for each variation continuous_accuracies = [] total_accuracies = [] # Get accuracy for each variation for reference_beats in _get_reference_beat_variations(reference_beats): # Annotations that have been used - n_annotations = np.max([reference_beats.shape[0], - estimated_beats.shape[0]]) + n_annotations = np.max([reference_beats.shape[0], estimated_beats.shape[0]]) used_annotations = np.zeros(n_annotations) # Whether or not we are continuous at any given point beat_successes = np.zeros(n_annotations) @@ -494,13 +494,15 @@ def continuity(reference_beats, # How far is the estimated beat from the reference beat, # relative to the inter-annotation-interval? if nearest + 1 < reference_beats.shape[0]: - reference_interval = (reference_beats[nearest + 1] - - reference_beats[nearest]) + reference_interval = ( + reference_beats[nearest + 1] - reference_beats[nearest] + ) else: # Special case when nearest + 1 is too large - use the # previous interval instead - reference_interval = (reference_beats[nearest] - - reference_beats[nearest - 1]) + reference_interval = ( + reference_beats[nearest] - reference_beats[nearest - 1] + ) # Handle this special case when beats are not unique if reference_interval == 0: if min_difference == 0: @@ -508,17 +510,15 @@ def continuity(reference_beats, else: phase = np.inf else: - phase = np.abs(min_difference/reference_interval) + phase = np.abs(min_difference / reference_interval) # How close is the inter-beat-interval # to the inter-annotation-interval? if m + 1 < estimated_beats.shape[0]: - estimated_interval = (estimated_beats[m + 1] - - estimated_beats[m]) + estimated_interval = estimated_beats[m + 1] - estimated_beats[m] else: # Special case when m + 1 is too large - use the # previous interval - estimated_interval = (estimated_beats[m] - - estimated_beats[m - 1]) + estimated_interval = estimated_beats[m] - estimated_beats[m - 1] # Handle this special case when beats are not unique if reference_interval == 0: if estimated_interval == 0: @@ -526,10 +526,11 @@ def continuity(reference_beats, else: period = np.inf else: - period = \ - np.abs(1 - estimated_interval/reference_interval) - if phase < continuity_phase_threshold and \ - period < continuity_period_threshold: + period = np.abs(1 - estimated_interval / reference_interval) + if ( + phase < continuity_phase_threshold + and period < continuity_period_threshold + ): # Set this annotation as used used_annotations[nearest] = 1 # This beat is matched @@ -538,18 +539,21 @@ def continuity(reference_beats, else: # How far is the estimated beat from the reference beat, # relative to the inter-annotation-interval? - reference_interval = (reference_beats[nearest] - - reference_beats[nearest - 1]) - phase = np.abs(min_difference/reference_interval) + reference_interval = ( + reference_beats[nearest] - reference_beats[nearest - 1] + ) + phase = np.abs(min_difference / reference_interval) # How close is the inter-beat-interval # to the inter-annotation-interval? - estimated_interval = (estimated_beats[m] - - estimated_beats[m - 1]) - reference_interval = (reference_beats[nearest] - - reference_beats[nearest - 1]) - period = np.abs(1 - estimated_interval/reference_interval) - if phase < continuity_phase_threshold and \ - period < continuity_period_threshold: + estimated_interval = estimated_beats[m] - estimated_beats[m - 1] + reference_interval = ( + reference_beats[nearest] - reference_beats[nearest - 1] + ) + period = np.abs(1 - estimated_interval / reference_interval) + if ( + phase < continuity_phase_threshold + and period < continuity_period_threshold + ): # Set this annotation as used used_annotations[nearest] = 1 # This beat is matched @@ -565,21 +569,21 @@ def continuity(reference_beats, beat_successes = beat_successes[1:-1] # Get the continuous accuracy as the longest track of successful beats longest_track = np.max(np.diff(beat_failures)) - 1 - continuous_accuracy = longest_track/(1.0*beat_successes.shape[0]) + continuous_accuracy = longest_track / (1.0 * beat_successes.shape[0]) continuous_accuracies.append(continuous_accuracy) # Get the total accuracy - all sequences - total_accuracy = np.sum(beat_successes)/(1.0*beat_successes.shape[0]) + total_accuracy = np.sum(beat_successes) / (1.0 * beat_successes.shape[0]) total_accuracies.append(total_accuracy) # Grab accuracy scores - return (continuous_accuracies[0], - total_accuracies[0], - np.max(continuous_accuracies), - np.max(total_accuracies)) + return ( + continuous_accuracies[0], + total_accuracies[0], + np.max(continuous_accuracies), + np.max(total_accuracies), + ) -def information_gain(reference_beats, - estimated_beats, - bins=41): +def information_gain(reference_beats, estimated_beats, bins=41): """Get the information gain - K-L divergence of the beat error histogram to a uniform histogram @@ -611,20 +615,25 @@ def information_gain(reference_beats, # If an even number of bins is provided, # there will be no bin centered at zero, so warn the user. if not bins % 2: - warnings.warn("bins parameter is even, " - "so there will not be a bin centered at zero.") + warnings.warn( + "bins parameter is even, " "so there will not be a bin centered at zero." + ) # Warn when only one beat is provided for either estimated or reference, # report a warning if reference_beats.size == 1: - warnings.warn("Only one reference beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one reference beat was provided, so beat intervals" + " cannot be computed." + ) if estimated_beats.size == 1: - warnings.warn("Only one estimated beat was provided, so beat intervals" - " cannot be computed.") + warnings.warn( + "Only one estimated beat was provided, so beat intervals" + " cannot be computed." + ) # When estimated or reference beats have <= 1 beats, can't compute the # metric, so return 0 if estimated_beats.size <= 1 or reference_beats.size <= 1: - return 0. + return 0.0 # Get entropy for reference beats->estimated beats # and estimated beats->reference beats forward_entropy = _get_entropy(reference_beats, estimated_beats, bins) @@ -633,15 +642,18 @@ def information_gain(reference_beats, norm = np.log2(bins) if forward_entropy > backward_entropy: # Note that the beat evaluation toolbox does not normalize - information_gain_score = (norm - forward_entropy)/norm + information_gain_score = (norm - forward_entropy) / norm else: - information_gain_score = (norm - backward_entropy)/norm + information_gain_score = (norm - backward_entropy) / norm return information_gain_score def _get_entropy(reference_beats, estimated_beats, bins): - """Helper function for information gain - (needs to be run twice - once backwards, once forwards) + """Compute the entropy of the beat error histogram. + + This is a helper function for the information gain + metric, and needs to be run twice: once backwards, once + forwards. Parameters ---------- @@ -656,7 +668,6 @@ def _get_entropy(reference_beats, estimated_beats, bins): ------- entropy : float Entropy of beat error histogram - """ beat_error = np.zeros(estimated_beats.shape[0]) for n in range(estimated_beats.shape[0]): @@ -667,34 +678,34 @@ def _get_entropy(reference_beats, estimated_beats, bins): # If the first annotation is closest... if closest_beat == 0: # Inter-annotation interval - space between first two beats - interval = .5*(reference_beats[1] - reference_beats[0]) + interval = 0.5 * (reference_beats[1] - reference_beats[0]) # If last annotation is closest... if closest_beat == (reference_beats.shape[0] - 1): - interval = .5*(reference_beats[-1] - reference_beats[-2]) + interval = 0.5 * (reference_beats[-1] - reference_beats[-2]) else: if absolute_error < 0: # Closest annotation is the one before the current beat # so look at previous inner-annotation-interval start = reference_beats[closest_beat] end = reference_beats[closest_beat - 1] - interval = .5*(start - end) + interval = 0.5 * (start - end) else: # Closest annotation is the one after the current beat # so look at next inner-annotation-interval start = reference_beats[closest_beat + 1] end = reference_beats[closest_beat] - interval = .5*(start - end) + interval = 0.5 * (start - end) # The actual error of this beat - beat_error[n] = .5*absolute_error/interval + beat_error[n] = 0.5 * absolute_error / interval # Put beat errors in range (-.5, .5) - beat_error = np.mod(beat_error + .5, -1) + .5 + beat_error = np.mod(beat_error + 0.5, -1) + 0.5 # Note these are slightly different the beat evaluation toolbox # (they are uniform) - histogram_bin_edges = np.linspace(-.5, .5, bins + 1) + histogram_bin_edges = np.linspace(-0.5, 0.5, bins + 1) # Get the histogram raw_bin_values = np.histogram(beat_error, histogram_bin_edges)[0] # Turn into a proper probability distribution - raw_bin_values = raw_bin_values/(1.0*np.sum(raw_bin_values)) + raw_bin_values = raw_bin_values / (1.0 * np.sum(raw_bin_values)) # Set zero-valued bins to 1 to make the entropy calculation well-behaved raw_bin_values[raw_bin_values == 0] = 1 # Calculate entropy @@ -716,7 +727,7 @@ def evaluate(reference_beats, estimated_beats, **kwargs): Reference beat times, in seconds estimated_beats : np.ndarray Query beat times, in seconds - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -727,7 +738,6 @@ def evaluate(reference_beats, estimated_beats, **kwargs): the value is the (float) score achieved. """ - # Trim beat times at the beginning of the annotations reference_beats = util.filter_kwargs(trim_beats, reference_beats, **kwargs) estimated_beats = util.filter_kwargs(trim_beats, estimated_beats, **kwargs) @@ -737,34 +747,36 @@ def evaluate(reference_beats, estimated_beats, **kwargs): scores = collections.OrderedDict() # F-Measure - scores['F-measure'] = util.filter_kwargs(f_measure, reference_beats, - estimated_beats, **kwargs) + scores["F-measure"] = util.filter_kwargs( + f_measure, reference_beats, estimated_beats, **kwargs + ) # Cemgil - scores['Cemgil'], scores['Cemgil Best Metric Level'] = \ - util.filter_kwargs(cemgil, reference_beats, estimated_beats, **kwargs) + scores["Cemgil"], scores["Cemgil Best Metric Level"] = util.filter_kwargs( + cemgil, reference_beats, estimated_beats, **kwargs + ) # Goto - scores['Goto'] = util.filter_kwargs(goto, reference_beats, - estimated_beats, **kwargs) + scores["Goto"] = util.filter_kwargs( + goto, reference_beats, estimated_beats, **kwargs + ) # P-Score - scores['P-score'] = util.filter_kwargs(p_score, reference_beats, - estimated_beats, **kwargs) + scores["P-score"] = util.filter_kwargs( + p_score, reference_beats, estimated_beats, **kwargs + ) # Continuity metrics - (scores['Correct Metric Level Continuous'], - scores['Correct Metric Level Total'], - scores['Any Metric Level Continuous'], - scores['Any Metric Level Total']) = util.filter_kwargs(continuity, - reference_beats, - estimated_beats, - **kwargs) + ( + scores["Correct Metric Level Continuous"], + scores["Correct Metric Level Total"], + scores["Any Metric Level Continuous"], + scores["Any Metric Level Total"], + ) = util.filter_kwargs(continuity, reference_beats, estimated_beats, **kwargs) # Information gain - scores['Information gain'] = util.filter_kwargs(information_gain, - reference_beats, - estimated_beats, - **kwargs) + scores["Information gain"] = util.filter_kwargs( + information_gain, reference_beats, estimated_beats, **kwargs + ) return scores diff --git a/mir_eval/chord.py b/mir_eval/chord.py index def9195e..69939e7a 100644 --- a/mir_eval/chord.py +++ b/mir_eval/chord.py @@ -1,4 +1,4 @@ -r''' +r""" Chord estimation algorithms produce a list of intervals and labels which denote the chord being played over each timespan. They are evaluated by comparing the estimated chord labels to some reference, usually using a mapping to a chord @@ -67,7 +67,7 @@ entire quality in closed voicing, i.e. spanning only a single octave; extended chords (9's, 11's and 13's) are rolled into a single octave with any upper voices included as extensions. For example, ('A:7', 'A:9') are - equivlent but ('A:7', 'A:maj7') are not. + equivalent but ('A:7', 'A:maj7') are not. * :func:`mir_eval.chord.tetrads_inv`: Same as above, with inversions (bass relationships). @@ -93,7 +93,7 @@ .. [#harte2010towards] C. Harte. Towards Automatic Extraction of Harmony Information from Music Signals. PhD thesis, Queen Mary University of London, August 2010. -''' +""" import numpy as np import warnings @@ -105,15 +105,15 @@ BITMAP_LENGTH = 12 NO_CHORD = "N" -NO_CHORD_ENCODED = -1, np.array([0]*BITMAP_LENGTH), -1 +NO_CHORD_ENCODED = -1, np.array([0] * BITMAP_LENGTH), -1 X_CHORD = "X" -X_CHORD_ENCODED = -1, np.array([-1]*BITMAP_LENGTH), -1 +X_CHORD_ENCODED = -1, np.array([-1] * BITMAP_LENGTH), -1 class InvalidChordException(Exception): - r'''Exception class for suspect / invalid chord labels''' + r"""Exception class for suspect / invalid chord labels""" - def __init__(self, message='', chord_label=None): + def __init__(self, message="", chord_label=None): self.message = message self.chord_label = chord_label self.name = self.__class__.__name__ @@ -122,16 +122,15 @@ def __init__(self, message='', chord_label=None): # --- Chord Primitives --- def _pitch_classes(): - r'''Map from pitch class (str) to semitone (int).''' - pitch_classes = ['C', 'D', 'E', 'F', 'G', 'A', 'B'] + r"""Map from pitch class (str) to semitone (int).""" + pitch_classes = ["C", "D", "E", "F", "G", "A", "B"] semitones = [0, 2, 4, 5, 7, 9, 11] return dict([(c, s) for c, s in zip(pitch_classes, semitones)]) def _scale_degrees(): - r'''Mapping from scale degrees (str) to semitones (int).''' - degrees = ['1', '2', '3', '4', '5', '6', '7', - '8', '9', '10', '11', '12', '13'] + r"""Map scale degrees (str) to semitones (int).""" + degrees = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13"] semitones = [0, 2, 4, 5, 7, 9, 11, 12, 14, 16, 17, 19, 21] return dict([(d, s) for d, s in zip(degrees, semitones)]) @@ -141,7 +140,7 @@ def _scale_degrees(): def pitch_class_to_semitone(pitch_class): - r'''Convert a pitch class to semitone. + r"""Convert a pitch class to semitone. Parameters ---------- @@ -153,18 +152,19 @@ def pitch_class_to_semitone(pitch_class): semitone : int Semitone value of the pitch class. - ''' + """ semitone = 0 for idx, char in enumerate(pitch_class): - if char == '#' and idx > 0: + if char == "#" and idx > 0: semitone += 1 - elif char == 'b' and idx > 0: + elif char == "b" and idx > 0: semitone -= 1 elif idx == 0: semitone = PITCH_CLASSES.get(char) else: raise InvalidChordException( - "Pitch class improperly formed: %s" % pitch_class) + "Pitch class improperly formed: %s" % pitch_class + ) return semitone % 12 @@ -177,7 +177,7 @@ def scale_degree_to_semitone(scale_degree): Parameters ---------- - scale degree : str + scale_degree : str Spelling of a relative scale degree, e.g. 'b3', '7', '#5' Returns @@ -194,15 +194,17 @@ def scale_degree_to_semitone(scale_degree): if scale_degree.startswith("#"): offset = scale_degree.count("#") scale_degree = scale_degree.strip("#") - elif scale_degree.startswith('b'): + elif scale_degree.startswith("b"): offset = -1 * scale_degree.count("b") scale_degree = scale_degree.strip("b") semitone = SCALE_DEGREES.get(scale_degree, None) if semitone is None: raise InvalidChordException( - "Scale degree improperly formed: {}, expected one of {}." - .format(scale_degree, list(SCALE_DEGREES.keys()))) + "Scale degree improperly formed: {}, expected one of {}.".format( + scale_degree, list(SCALE_DEGREES.keys()) + ) + ) return semitone + offset @@ -242,35 +244,36 @@ def scale_degree_to_bitmap(scale_degree, modulo=False, length=BITMAP_LENGTH): # semitones, i.e. vector[0] is the tonic. QUALITIES = { # 1 2 3 4 5 6 7 - 'maj': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], - 'min': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - 'aug': [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], - 'dim': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], - 'sus4': [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], - 'sus2': [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], - '7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - 'maj7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], - 'min7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], - 'minmaj7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1], - 'maj6': [1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0], - 'min6': [1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0], - 'dim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0], - 'hdim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0], - 'maj9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], - 'min9': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], - '9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - 'b9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - '#9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - 'min11': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], - '11': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - '#11': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - 'maj13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], - 'min13': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], - '13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - 'b13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], - '1': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - '5': [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - '': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} + "maj": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + "min": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], + "aug": [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + "dim": [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], + "sus4": [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], + "sus2": [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + "7": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "maj7": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], + "min7": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + "minmaj7": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1], + "maj6": [1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0], + "min6": [1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0], + "dim7": [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0], + "hdim7": [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0], + "maj9": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], + "min9": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + "9": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "b9": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "#9": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "min11": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + "11": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "#11": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "maj13": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], + "min13": [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + "13": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "b13": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + "1": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "5": [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + "": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +} def quality_to_bitmap(quality): @@ -290,7 +293,8 @@ def quality_to_bitmap(quality): if quality not in QUALITIES: raise InvalidChordException( "Unsupported chord quality shorthand: '%s' " - "Did you mean to reduce extended chords?" % quality) + "Did you mean to reduce extended chords?" % quality + ) return np.array(QUALITIES[quality]) @@ -299,19 +303,20 @@ def quality_to_bitmap(quality): # TODO(ejhumphrey): Revisit how minmaj7's are mapped. This is how TMC did it, # but MMV handles it like a separate quality (rather than an add7). EXTENDED_QUALITY_REDUX = { - 'minmaj7': ('min', set(['7'])), - 'maj9': ('maj7', set(['9'])), - 'min9': ('min7', set(['9'])), - '9': ('7', set(['9'])), - 'b9': ('7', set(['b9'])), - '#9': ('7', set(['#9'])), - '11': ('7', set(['9', '11'])), - '#11': ('7', set(['9', '#11'])), - '13': ('7', set(['9', '11', '13'])), - 'b13': ('7', set(['9', '11', 'b13'])), - 'min11': ('min7', set(['9', '11'])), - 'maj13': ('maj7', set(['9', '11', '13'])), - 'min13': ('min7', set(['9', '11', '13']))} + "minmaj7": ("min", set(["7"])), + "maj9": ("maj7", set(["9"])), + "min9": ("min7", set(["9"])), + "9": ("7", set(["9"])), + "b9": ("7", set(["b9"])), + "#9": ("7", set(["#9"])), + "11": ("7", set(["9", "11"])), + "#11": ("7", set(["9", "#11"])), + "13": ("7", set(["9", "11", "13"])), + "b13": ("7", set(["9", "11", "b13"])), + "min11": ("min7", set(["9", "11"])), + "maj13": ("maj7", set(["9", "11", "13"])), + "min13": ("min7", set(["9", "11", "13"])), +} def reduce_extended_quality(quality): @@ -340,20 +345,19 @@ def validate_chord_label(chord_label): Parameters ---------- - chord : str + chord_label : str Chord label to validate. - """ - # This monster regexp is pulled from the JAMS chord namespace, # which is in turn derived from the context-free grammar of # Harte et al., 2005. - pattern = re.compile(r'''^((N|X)|(([A-G](b*|#*))((:(maj|min|dim|aug|1|5|sus2|sus4|maj6|min6|7|maj7|min7|dim7|hdim7|minmaj7|aug7|9|maj9|min9|11|maj11|min11|13|maj13|min13)(\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\))?)|(:\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\)))?((/((b*|#*)([1-9]|1[0-3]?)))?)?))$''') # nopep8 + pattern = re.compile( + r"""^((N|X)|(([A-G](b*|#*))((:(maj|min|dim|aug|1|5|sus2|sus4|maj6|min6|7|maj7|min7|dim7|hdim7|minmaj7|aug7|9|maj9|min9|11|maj11|min11|13|maj13|min13)(\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\))?)|(:\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\)))?((/((b*|#*)([1-9]|1[0-3]?)))?)?))$""" + ) # nopep8 if not pattern.match(chord_label): - raise InvalidChordException('Invalid chord label: ' - '{}'.format(chord_label)) + raise InvalidChordException("Invalid chord label: " "{}".format(chord_label)) pass @@ -392,9 +396,9 @@ def split(chord_label, reduce_extended_chords=False): chord_label = str(chord_label) validate_chord_label(chord_label) if chord_label == NO_CHORD: - return [chord_label, '', set(), ''] + return [chord_label, "", set(), ""] - bass = '1' + bass = "1" if "/" in chord_label: chord_label, bass = chord_label.split("/") @@ -413,8 +417,9 @@ def split(chord_label, reduce_extended_chords=False): # Intervals specifying omissions MUST have a quality. if omission and ":" not in chord_label: raise InvalidChordException( - "Intervals specifying omissions MUST have a quality.") - quality = '' if scale_degrees else 'maj' + "Intervals specifying omissions MUST have a quality." + ) + quality = "" if scale_degrees else "maj" if ":" in chord_label: chord_root, quality_name = chord_label.split(":") # Extended chords (with ":"s) may not explicitly have Major qualities, @@ -431,7 +436,7 @@ def split(chord_label, reduce_extended_chords=False): return [chord_root, quality, scale_degrees, bass] -def join(chord_root, quality='', extensions=None, bass=''): +def join(chord_root, quality="", extensions=None, bass=""): r"""Join the parts of a chord into a complete chord label. Parameters @@ -459,15 +464,14 @@ def join(chord_root, quality='', extensions=None, bass=''): chord_label += ":%s" % quality if extensions: chord_label += "(%s)" % ",".join(extensions) - if bass and bass != '1': + if bass and bass != "1": chord_label += "/%s" % bass validate_chord_label(chord_label) return chord_label # --- Chords to Numerical Representations --- -def encode(chord_label, reduce_extended_chords=False, - strict_bass_intervals=False): +def encode(chord_label, reduce_extended_chords=False, strict_bass_intervals=False): """Translate a chord label to numerical representations for evaluation. Parameters @@ -490,15 +494,14 @@ def encode(chord_label, reduce_extended_chords=False, 12-dim vector of relative semitones in the chord spelling. bass_number : int Relative semitone of the chord's bass note, e.g. 0=root, 7=fifth, etc. - """ - if chord_label == NO_CHORD: return NO_CHORD_ENCODED if chord_label == X_CHORD: return X_CHORD_ENCODED chord_root, quality, scale_degrees, bass = split( - chord_label, reduce_extended_chords=reduce_extended_chords) + chord_label, reduce_extended_chords=reduce_extended_chords + ) root_number = pitch_class_to_semitone(chord_root) bass_number = scale_degree_to_semitone(bass) % 12 @@ -507,14 +510,14 @@ def encode(chord_label, reduce_extended_chords=False, semitone_bitmap[0] = 1 for scale_degree in scale_degrees: - semitone_bitmap += scale_degree_to_bitmap(scale_degree, - reduce_extended_chords) + semitone_bitmap += scale_degree_to_bitmap(scale_degree, reduce_extended_chords) semitone_bitmap = (semitone_bitmap > 0).astype(np.int64) if not semitone_bitmap[bass_number] and strict_bass_intervals: raise InvalidChordException( - "Given bass scale degree is absent from this chord: " - "%s" % chord_label, chord_label) + "Given bass scale degree is absent from this chord: " "%s" % chord_label, + chord_label, + ) else: semitone_bitmap[bass_number] = 1 return root_number, semitone_bitmap, bass_number @@ -557,7 +560,7 @@ def encode_many(chord_labels, reduce_extended_chords=False): def rotate_bitmap_to_root(bitmap, chord_root): - """Circularly shift a relative bitmap to its asbolute pitch classes. + """Circularly shift a relative bitmap to its absolute pitch classes. For clarity, the best explanation is an example. Given 'G:Maj', the root and quality map are as follows:: @@ -592,15 +595,15 @@ def rotate_bitmap_to_root(bitmap, chord_root): def rotate_bitmaps_to_roots(bitmaps, roots): - """Circularly shift a relative bitmaps to asbolute pitch classes. + """Circularly shift a relative bitmaps to absolute pitch classes. See :func:`rotate_bitmap_to_root` for more information. Parameters ---------- - bitmap : np.ndarray, shape=(N, 12) + bitmaps : np.ndarray, shape=(N, 12) Bitmap of active notes, relative to the given root. - root : np.ndarray, shape=(N,) + roots : np.ndarray, shape=(N,) Absolute pitch class number. Returns @@ -617,7 +620,7 @@ def rotate_bitmaps_to_roots(bitmaps, roots): # --- Comparison Routines --- def validate(reference_labels, estimated_labels): - """Checks that the input annotations to a comparison function look like + """Check that the input annotations to a comparison function look like valid chord labels. Parameters @@ -626,22 +629,22 @@ def validate(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. - """ N = len(reference_labels) M = len(estimated_labels) if N != M: raise ValueError( "Chord comparison received different length lists: " - "len(reference)=%d\tlen(estimates)=%d" % (N, M)) + "len(reference)=%d\tlen(estimates)=%d" % (N, M) + ) for labels in [reference_labels, estimated_labels]: for chord_label in labels: validate_chord_label(chord_label) # When either label list is empty, warn the user if len(reference_labels) == 0: - warnings.warn('Reference labels are empty') + warnings.warn("Reference labels are empty") if len(estimated_labels) == 0: - warnings.warn('Estimated labels are empty') + warnings.warn("Estimated labels are empty") def weighted_accuracy(comparisons, weights): @@ -684,29 +687,32 @@ def weighted_accuracy(comparisons, weights): N = len(comparisons) # There should be as many weights as comparisons if weights.shape[0] != N: - raise ValueError('weights and comparisons should be of the same' - ' length. len(weights) = {} but len(comparisons)' - ' = {}'.format(weights.shape[0], N)) + raise ValueError( + "weights and comparisons should be of the same" + " length. len(weights) = {} but len(comparisons)" + " = {}".format(weights.shape[0], N) + ) if (weights < 0).any(): - raise ValueError('Weights should all be positive.') + raise ValueError("Weights should all be positive.") if np.sum(weights) == 0: - warnings.warn('No nonzero weights, returning 0') + warnings.warn("No nonzero weights, returning 0") return 0 # Find all comparison scores which are valid - valid_idx = (comparisons >= 0) + valid_idx = comparisons >= 0 # If no comparable chords were provided, warn and return 0 if valid_idx.sum() == 0: - warnings.warn("No reference chords were comparable " - "to estimated chords, returning 0.") + warnings.warn( + "No reference chords were comparable " "to estimated chords, returning 0." + ) return 0 # Remove any uncomparable labels comparisons = comparisons[valid_idx] weights = weights[valid_idx] # Normalize the weights total_weight = float(np.sum(weights)) - normalized_weights = np.asarray(weights, dtype=float)/total_weight + normalized_weights = np.asarray(weights, dtype=float) / total_weight # Score is the sum of all weighted comparisons - return np.sum(comparisons*normalized_weights) + return np.sum(comparisons * normalized_weights) def thirds(reference_labels, estimated_labels): @@ -843,8 +849,7 @@ def triads(reference_labels, estimated_labels): est_roots, est_semitones = encode_many(estimated_labels, False)[:2] eq_roots = ref_roots == est_roots - eq_semitones = np.all( - np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) + eq_semitones = np.all(np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) comparison_scores = (eq_roots * eq_semitones).astype(np.float64) # Ignore 'X' chords @@ -892,8 +897,7 @@ def triads_inv(reference_labels, estimated_labels): eq_roots = ref_roots == est_roots eq_basses = ref_bass == est_bass - eq_semitones = np.all( - np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) + eq_semitones = np.all(np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) comparison_scores = (eq_roots * eq_semitones * eq_basses).astype(np.float64) # Ignore 'X' chords @@ -1029,9 +1033,7 @@ def root(reference_labels, estimated_labels): comparison_scores : np.ndarray, shape=(n,), dtype=float Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of gamut. - """ - validate(reference_labels, estimated_labels) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] est_roots = encode_many(estimated_labels, False)[0] @@ -1096,8 +1098,9 @@ def mirex(reference_labels, estimated_labels): # Skip chords where the number of active semitones `n` is # 0 < n < `min_intersection`. ref_semitone_count = (ref_data[1] > 0).sum(axis=1) - skip_idx = np.logical_and(ref_semitone_count > 0, - ref_semitone_count < min_intersection) + skip_idx = np.logical_and( + ref_semitone_count > 0, ref_semitone_count < min_intersection + ) # Also ignore 'X' chords. np.logical_or(skip_idx, np.any(ref_data[1] < 0, axis=1), skip_idx) comparison_scores[skip_idx] = -1.0 @@ -1141,15 +1144,14 @@ def majmin(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - maj_semitones = np.array(QUALITIES['maj'][:8]) - min_semitones = np.array(QUALITIES['min'][:8]) + maj_semitones = np.array(QUALITIES["maj"][:8]) + min_semitones = np.array(QUALITIES["min"][:8]) ref_roots, ref_semitones, _ = encode_many(reference_labels, False) est_roots, est_semitones, _ = encode_many(estimated_labels, False) eq_root = ref_roots == est_roots - eq_quality = np.all(np.equal(ref_semitones[:, :8], - est_semitones[:, :8]), axis=1) + eq_quality = np.all(np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) comparison_scores = (eq_root * eq_quality).astype(np.float64) # Test for Major / Minor / No-chord @@ -1208,15 +1210,14 @@ def majmin_inv(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - maj_semitones = np.array(QUALITIES['maj'][:8]) - min_semitones = np.array(QUALITIES['min'][:8]) + maj_semitones = np.array(QUALITIES["maj"][:8]) + min_semitones = np.array(QUALITIES["min"][:8]) ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) eq_root_bass = (ref_roots == est_roots) * (ref_bass == est_bass) - eq_semitones = np.all(np.equal(ref_semitones[:, :8], - est_semitones[:, :8]), axis=1) + eq_semitones = np.all(np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1) comparison_scores = (eq_root_bass * eq_semitones).astype(np.float64) # Test for Major / Minor / No-chord @@ -1272,7 +1273,7 @@ def sevenths(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', ''] + seventh_qualities = ["maj", "min", "maj7", "7", "min7", ""] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] @@ -1283,8 +1284,12 @@ def sevenths(reference_labels, estimated_labels): comparison_scores = (eq_root * eq_semitones).astype(np.float64) # Test for reference chord inclusion - is_valid = np.array([np.all(np.equal(ref_semitones, semitones), axis=1) - for semitones in valid_semitones]) + is_valid = np.array( + [ + np.all(np.equal(ref_semitones, semitones), axis=1) + for semitones in valid_semitones + ] + ) # Drop if NOR comparison_scores[np.sum(is_valid, axis=0) == 0] = -1 return comparison_scores @@ -1327,7 +1332,7 @@ def sevenths_inv(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', ''] + seventh_qualities = ["maj", "min", "maj7", "7", "min7", ""] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) ref_roots, ref_semitones, ref_basses = encode_many(reference_labels, False) @@ -1338,8 +1343,12 @@ def sevenths_inv(reference_labels, estimated_labels): comparison_scores = (eq_roots_basses * eq_semitones).astype(np.float64) # Test for Major / Minor / No-chord - is_valid = np.array([np.all(np.equal(ref_semitones, semitones), axis=1) - for semitones in valid_semitones]) + is_valid = np.array( + [ + np.all(np.equal(ref_semitones, semitones), axis=1) + for semitones in valid_semitones + ] + ) comparison_scores[np.sum(is_valid, axis=0) == 0] = -1 # Disable inversions that are not part of the quality @@ -1384,12 +1393,14 @@ def directional_hamming_distance(reference_intervals, estimated_intervals): util.validate_intervals(reference_intervals) # make sure chord intervals do not overlap - if len(reference_intervals) > 1 and (reference_intervals[:-1, 1] > - reference_intervals[1:, 0]).any(): - raise ValueError('Chord Intervals must not overlap') + if ( + len(reference_intervals) > 1 + and (reference_intervals[:-1, 1] > reference_intervals[1:, 0]).any() + ): + raise ValueError("Chord Intervals must not overlap") est_ts = np.unique(estimated_intervals.flatten()) - seg = 0. + seg = 0.0 for start, end in reference_intervals: dur = end - start between_start_end = est_ts[(est_ts >= start) & (est_ts < end)] @@ -1421,8 +1432,7 @@ def overseg(reference_intervals, estimated_intervals): oversegmentation score : float Comparison score, in [0.0, 1.0], where 1.0 means no oversegmentation. """ - return 1 - directional_hamming_distance(reference_intervals, - estimated_intervals) + return 1 - directional_hamming_distance(reference_intervals, estimated_intervals) def underseg(reference_intervals, estimated_intervals): @@ -1448,8 +1458,7 @@ def underseg(reference_intervals, estimated_intervals): undersegmentation score : float Comparison score, in [0.0, 1.0], where 1.0 means no undersegmentation. """ - return 1 - directional_hamming_distance(estimated_intervals, - reference_intervals) + return 1 - directional_hamming_distance(estimated_intervals, reference_intervals) def seg(reference_intervals, estimated_intervals): @@ -1475,9 +1484,10 @@ def seg(reference_intervals, estimated_intervals): segmentation score : float Comparison score, in [0.0, 1.0], where 1.0 means perfect segmentation. """ - - return min(underseg(reference_intervals, estimated_intervals), - overseg(reference_intervals, estimated_intervals)) + return min( + underseg(reference_intervals, estimated_intervals), + overseg(reference_intervals, estimated_intervals), + ) def merge_chord_intervals(intervals, labels): @@ -1504,8 +1514,9 @@ def merge_chord_intervals(intervals, labels): prev_rt = None prev_st = None prev_ba = None - for s, e, rt, st, ba in zip(intervals[:, 0], intervals[:, 1], - roots, semitones, basses): + for s, e, rt, st, ba in zip( + intervals[:, 0], intervals[:, 1], roots, semitones, basses + ): if rt != prev_rt or (st != prev_st).any() or ba != prev_ba: prev_rt, prev_st, prev_ba = rt, st, ba merged_ivs.append([s, e]) @@ -1515,7 +1526,7 @@ def merge_chord_intervals(intervals, labels): def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): - """Computes weighted accuracy for all comparison functions for the given + """Compute weighted accuracy for all comparison functions for the given reference and estimated annotations. Examples @@ -1532,20 +1543,16 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): ref_intervals : np.ndarray, shape=(n, 2) Reference chord intervals, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - ref_labels : list, shape=(n,) reference chord labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - est_intervals : np.ndarray, shape=(m, 2) estimated chord intervals, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - est_labels : list, shape=(m,) estimated chord labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -1558,47 +1565,50 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): """ # Append or crop estimated intervals so their span is the same as reference est_intervals, est_labels = util.adjust_intervals( - est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), - NO_CHORD, NO_CHORD) + est_intervals, + est_labels, + ref_intervals.min(), + ref_intervals.max(), + NO_CHORD, + NO_CHORD, + ) # use merged intervals for segmentation evaluation merged_ref_intervals = merge_chord_intervals(ref_intervals, ref_labels) merged_est_intervals = merge_chord_intervals(est_intervals, est_labels) # Adjust the labels so that they span the same intervals intervals, ref_labels, est_labels = util.merge_labeled_intervals( - ref_intervals, ref_labels, est_intervals, est_labels) + ref_intervals, ref_labels, est_intervals, est_labels + ) # Convert intervals to durations (used as weights) durations = util.intervals_to_durations(intervals) # Store scores for each comparison function scores = collections.OrderedDict() - scores['thirds'] = weighted_accuracy(thirds(ref_labels, est_labels), - durations) - scores['thirds_inv'] = weighted_accuracy(thirds_inv(ref_labels, - est_labels), durations) - scores['triads'] = weighted_accuracy(triads(ref_labels, est_labels), - durations) - scores['triads_inv'] = weighted_accuracy(triads_inv(ref_labels, - est_labels), durations) - scores['tetrads'] = weighted_accuracy(tetrads(ref_labels, est_labels), - durations) - scores['tetrads_inv'] = weighted_accuracy(tetrads_inv(ref_labels, - est_labels), - durations) - scores['root'] = weighted_accuracy(root(ref_labels, est_labels), durations) - scores['mirex'] = weighted_accuracy(mirex(ref_labels, est_labels), - durations) - scores['majmin'] = weighted_accuracy(majmin(ref_labels, est_labels), - durations) - scores['majmin_inv'] = weighted_accuracy(majmin_inv(ref_labels, - est_labels), durations) - scores['sevenths'] = weighted_accuracy(sevenths(ref_labels, est_labels), - durations) - scores['sevenths_inv'] = weighted_accuracy(sevenths_inv(ref_labels, - est_labels), - durations) - scores['underseg'] = underseg(merged_ref_intervals, merged_est_intervals) - scores['overseg'] = overseg(merged_ref_intervals, merged_est_intervals) - scores['seg'] = min(scores['overseg'], scores['underseg']) + scores["thirds"] = weighted_accuracy(thirds(ref_labels, est_labels), durations) + scores["thirds_inv"] = weighted_accuracy( + thirds_inv(ref_labels, est_labels), durations + ) + scores["triads"] = weighted_accuracy(triads(ref_labels, est_labels), durations) + scores["triads_inv"] = weighted_accuracy( + triads_inv(ref_labels, est_labels), durations + ) + scores["tetrads"] = weighted_accuracy(tetrads(ref_labels, est_labels), durations) + scores["tetrads_inv"] = weighted_accuracy( + tetrads_inv(ref_labels, est_labels), durations + ) + scores["root"] = weighted_accuracy(root(ref_labels, est_labels), durations) + scores["mirex"] = weighted_accuracy(mirex(ref_labels, est_labels), durations) + scores["majmin"] = weighted_accuracy(majmin(ref_labels, est_labels), durations) + scores["majmin_inv"] = weighted_accuracy( + majmin_inv(ref_labels, est_labels), durations + ) + scores["sevenths"] = weighted_accuracy(sevenths(ref_labels, est_labels), durations) + scores["sevenths_inv"] = weighted_accuracy( + sevenths_inv(ref_labels, est_labels), durations + ) + scores["underseg"] = underseg(merged_ref_intervals, merged_est_intervals) + scores["overseg"] = overseg(merged_ref_intervals, merged_est_intervals) + scores["seg"] = min(scores["overseg"], scores["underseg"]) return scores diff --git a/mir_eval/display.py b/mir_eval/display.py index 39187b84..a67ae326 100644 --- a/mir_eval/display.py +++ b/mir_eval/display.py @@ -1,5 +1,5 @@ # -*- encoding: utf-8 -*- -'''Display functions''' +"""Display functions""" from collections import defaultdict @@ -16,15 +16,14 @@ from .util import midi_to_hz, hz_to_midi -def __expand_limits(ax, limits, which='x'): - '''Helper function to expand axis limits''' - - if which == 'x': +def __expand_limits(ax, limits, which="x"): + """Expand axis limits""" + if which == "x": getter, setter = ax.get_xlim, ax.set_xlim - elif which == 'y': + elif which == "y": getter, setter = ax.get_ylim, ax.set_ylim else: - raise ValueError('invalid axis: {}'.format(which)) + raise ValueError("invalid axis: {}".format(which)) old_lims = getter() new_lims = list(limits) @@ -40,7 +39,7 @@ def __expand_limits(ax, limits, which='x'): def __get_axes(ax=None, fig=None): - '''Get or construct the target axes object for a new plot. + """Get or construct the target axes object for a new plot. Parameters ---------- @@ -57,13 +56,10 @@ def __get_axes(ax=None, fig=None): ax : matplotlib.pyplot.axes An axis handle on which to draw the segmentation. If none is provided, a new set of axes is created. - new_axes : bool If `True`, the axis object was newly constructed. If `False`, the axis object already existed. - - ''' - + """ new_axes = False if ax is not None: @@ -71,6 +67,7 @@ def __get_axes(ax=None, fig=None): if fig is None: import matplotlib.pyplot as plt + fig = plt.gcf() if not fig.get_axes(): @@ -79,9 +76,17 @@ def __get_axes(ax=None, fig=None): return fig.gca(), new_axes -def segments(intervals, labels, base=None, height=None, text=False, - text_kw=None, ax=None, **kwargs): - '''Plot a segmentation as a set of disjoint rectangles. +def segments( + intervals, + labels, + base=None, + height=None, + text=False, + text_kw=None, + ax=None, + **kwargs +): + """Plot a segmentation as a set of disjoint rectangles. Parameters ---------- @@ -89,33 +94,26 @@ def segments(intervals, labels, base=None, height=None, text=False, segment intervals, in the format returned by :func:`mir_eval.io.load_intervals` or :func:`mir_eval.io.load_labeled_intervals`. - labels : list, shape=(n,) reference segment labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - base : number The vertical position of the base of the rectangles. By default, this will be the bottom of the plot. - height : number The height of the rectangles. By default, this will be the top of the plot (minus ``base``). - text : bool If true, each segment's label is displayed in its upper-left corner - text_kw : dict If ``text == True``, the properties of the text object can be specified here. See ``matplotlib.pyplot.Text`` for valid parameters - ax : matplotlib.pyplot.axes An axis handle on which to draw the segmentation. If none is provided, a new set of axes is created. - - kwargs + **kwargs Additional keyword arguments to pass to ``matplotlib.patches.Rectangle``. @@ -123,12 +121,12 @@ def segments(intervals, labels, base=None, height=None, text=False, ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' + """ if text_kw is None: text_kw = dict() - text_kw.setdefault('va', 'top') - text_kw.setdefault('clip_on', True) - text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white')) + text_kw.setdefault("va", "top") + text_kw.setdefault("clip_on", True) + text_kw.setdefault("bbox", dict(boxstyle="round", facecolor="white")) # Make sure we have a numpy array intervals = np.atleast_2d(intervals) @@ -159,21 +157,24 @@ def segments(intervals, labels, base=None, height=None, text=False, seg_map[lab] = seg_def_style.copy() seg_map[lab].update(style) # Swap color -> facecolor here so we preserve edgecolor on rects - seg_map[lab]['facecolor'] = seg_map[lab].pop('color') + seg_map[lab]["facecolor"] = seg_map[lab].pop("color") seg_map[lab].update(kwargs) - seg_map[lab]['label'] = lab + seg_map[lab]["label"] = lab for ival, lab in zip(intervals, labels): - rect = Rectangle((ival[0], base), ival[1] - ival[0], height, - **seg_map[lab]) + rect = Rectangle((ival[0], base), ival[1] - ival[0], height, **seg_map[lab]) ax.add_patch(rect) - seg_map[lab].pop('label', None) + seg_map[lab].pop("label", None) if text: - ann = ax.annotate(lab, - xy=(ival[0], height), xycoords='data', - xytext=(8, -10), textcoords='offset points', - **text_kw) + ann = ax.annotate( + lab, + xy=(ival[0], height), + xycoords="data", + xytext=(8, -10), + textcoords="offset points", + **text_kw + ) ann.set_clip_path(rect) if new_axes: @@ -181,15 +182,23 @@ def segments(intervals, labels, base=None, height=None, text=False, # Only expand if we have data if intervals.size: - __expand_limits(ax, [intervals.min(), intervals.max()], which='x') + __expand_limits(ax, [intervals.min(), intervals.max()], which="x") return ax -def labeled_intervals(intervals, labels, label_set=None, - base=None, height=None, extend_labels=True, - ax=None, tick=True, **kwargs): - '''Plot labeled intervals with each label on its own row. +def labeled_intervals( + intervals, + labels, + label_set=None, + base=None, + height=None, + extend_labels=True, + ax=None, + tick=True, + **kwargs +): + """Plot labeled intervals with each label on its own row. Parameters ---------- @@ -235,7 +244,7 @@ def labeled_intervals(intervals, labels, label_set=None, tick : bool If ``True``, sets tick positions and labels on the y-axis. - kwargs + **kwargs Additional keyword arguments to pass to `matplotlib.collection.BrokenBarHCollection`. @@ -243,8 +252,7 @@ def labeled_intervals(intervals, labels, label_set=None, ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' - + """ # Get the axes handle ax, _ = __get_axes(ax=ax) @@ -272,7 +280,7 @@ def labeled_intervals(intervals, labels, label_set=None, style.update(next(ax._get_patches_for_fill.prop_cycler)) # Swap color -> facecolor here so we preserve edgecolor on rects - style['facecolor'] = style.pop('color') + style["facecolor"] = style.pop("color") style.update(kwargs) if base is None: @@ -295,33 +303,32 @@ def labeled_intervals(intervals, labels, label_set=None, xvals[lab].append((ival[0], ival[1] - ival[0])) for lab in seg_y: - ax.add_collection(BrokenBarHCollection(xvals[lab], seg_y[lab], - **style)) + ax.add_collection(BrokenBarHCollection(xvals[lab], seg_y[lab], **style)) # Pop the label after the first time we see it, so we only get # one legend entry - style.pop('label', None) + style.pop("label", None) # Draw a line separating the new labels from pre-existing labels if label_set != ticks: - ax.axhline(len(label_set), color='k', alpha=0.5) + ax.axhline(len(label_set), color="k", alpha=0.5) if tick: - ax.grid(True, axis='y') + ax.grid(True, axis="y") ax.set_yticks([]) ax.set_yticks(base) - ax.set_yticklabels(ticks, va='bottom') + ax.set_yticklabels(ticks, va="bottom") ax.yaxis.set_major_formatter(IntervalFormatter(base, ticks)) if base.size: - __expand_limits(ax, [base.min(), (base + height).max()], which='y') + __expand_limits(ax, [base.min(), (base + height).max()], which="y") if intervals.size: - __expand_limits(ax, [intervals.min(), intervals.max()], which='x') + __expand_limits(ax, [intervals.min(), intervals.max()], which="x") return ax class IntervalFormatter(Formatter): - '''Ticker formatter for labeled interval plots. + """Ticker formatter for labeled interval plots. Parameters ---------- @@ -330,18 +337,18 @@ class IntervalFormatter(Formatter): ticks : array-like of string The labels for the ticks - ''' - def __init__(self, base, ticks): + """ + def __init__(self, base, ticks): self._map = {int(k): v for k, v in zip(base, ticks)} def __call__(self, x, pos=None): - - return self._map.get(int(x), '') + """Map the input position to its corresponding interval label""" + return self._map.get(int(x), "") def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): - '''Plot a hierarchical segmentation + """Plot a hierarchical segmentation Parameters ---------- @@ -351,25 +358,24 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): :func:`mir_eval.io.load_intervals` or :func:`mir_eval.io.load_labeled_intervals`. Segmentations should be ordered by increasing specificity. - labels_hier : list of list-like A list of segmentation labels. Each element should be a list of labels for the corresponding element in `intervals_hier`. - levels : list of string Each element ``levels[i]`` is a label for the ```i`` th segmentation. This is used in the legend to denote the levels in a segment hierarchy. - - kwargs + ax : matplotlib.pyplot.axes + An axis handle on which to draw the intervals. + If none is provided, a new set of axes is created. + **kwargs Additional keyword arguments to `labeled_intervals`. Returns ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' - + """ # This will break if a segment label exists in multiple levels if levels is None: levels = list(range(len(intervals_hier))) @@ -380,9 +386,7 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): # Count the pre-existing patches n_patches = len(ax.patches) - for ints, labs, key in zip(intervals_hier[::-1], - labels_hier[::-1], - levels[::-1]): + for ints, labs, key in zip(intervals_hier[::-1], labels_hier[::-1], levels[::-1]): labeled_intervals(ints, labs, label=key, ax=ax, **kwargs) # Reverse the patch ordering for anything we've added. @@ -391,9 +395,8 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): return ax -def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, - **kwargs): - '''Plot event times as a set of vertical lines +def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, **kwargs): + """Plot event times as a set of vertical lines Parameters ---------- @@ -401,29 +404,23 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, event times, in the format returned by :func:`mir_eval.io.load_events` or :func:`mir_eval.io.load_labeled_events`. - labels : list, shape=(n,), optional event labels, in the format returned by :func:`mir_eval.io.load_labeled_events`. - base : number The vertical position of the base of the line. By default, this will be the bottom of the plot. - height : number The height of the lines. By default, this will be the top of the plot (minus `base`). - ax : matplotlib.pyplot.axes An axis handle on which to draw the segmentation. If none is provided, a new set of axes is created. - text_kw : dict If `labels` is provided, the properties of the text objects can be specified here. See `matplotlib.pyplot.Text` for valid parameters - - kwargs + **kwargs Additional keyword arguments to pass to `matplotlib.pyplot.vlines`. @@ -431,12 +428,12 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' + """ if text_kw is None: text_kw = dict() - text_kw.setdefault('va', 'top') - text_kw.setdefault('clip_on', True) - text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white')) + text_kw.setdefault("va", "top") + text_kw.setdefault("clip_on", True) + text_kw.setdefault("bbox", dict(boxstyle="round", facecolor="white")) # make sure we have an array for times times = np.asarray(times) @@ -466,32 +463,35 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, style = next(cycler).copy() style.update(kwargs) # If the user provided 'colors', don't override it with 'color' - if 'colors' in style: - style.pop('color', None) + if "colors" in style: + style.pop("color", None) lines = ax.vlines(times, base, base + height, **style) if labels: for path, lab in zip(lines.get_paths(), labels): - ax.annotate(lab, - xy=(path.vertices[0][0], height), - xycoords='data', - xytext=(8, -10), textcoords='offset points', - **text_kw) + ax.annotate( + lab, + xy=(path.vertices[0][0], height), + xycoords="data", + xytext=(8, -10), + textcoords="offset points", + **text_kw + ) if new_axes: ax.set_yticks([]) - __expand_limits(ax, [base, base + height], which='y') + __expand_limits(ax, [base, base + height], which="y") if times.size: - __expand_limits(ax, [times.min(), times.max()], which='x') + __expand_limits(ax, [times.min(), times.max()], which="x") return ax def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): - '''Visualize pitch contours + """Visualize pitch contours Parameters ---------- @@ -517,22 +517,20 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): An axis handle on which to draw the pitch contours. If none is provided, a new set of axes is created. - kwargs + **kwargs Additional keyword arguments to `matplotlib.pyplot.plot`. Returns ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' - + """ ax, _ = __get_axes(ax=ax) times = np.asarray(times) # First, segment into contiguously voiced contours - frequencies, voicings = freq_to_voicing(np.asarray(frequencies, - dtype=np.float64)) + frequencies, voicings = freq_to_voicing(np.asarray(frequencies, dtype=np.float64)) voicings = voicings.astype(bool) # Here are all the change-points @@ -564,20 +562,19 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): for idx in v_slices: ax.plot(times[idx], frequencies[idx], **style) - style.pop('label', None) + style.pop("label", None) # Plot the unvoiced portions if unvoiced: - style['alpha'] = style.get('alpha', 1.0) * 0.5 + style["alpha"] = style.get("alpha", 1.0) * 0.5 for idx in u_slices: ax.plot(times[idx], frequencies[idx], **style) return ax -def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, - **kwargs): - '''Visualize multiple f0 measurements +def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): + """Visualize multiple f0 measurements Parameters ---------- @@ -606,15 +603,14 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, An axis handle on which to draw the pitch contours. If none is provided, a new set of axes is created. - kwargs + **kwargs Additional keyword arguments to `plt.scatter`. Returns ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' - + """ # Get the axes handle ax, _ = __get_axes(ax=ax) @@ -624,8 +620,8 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, style_voiced.update(kwargs) style_unvoiced = style_voiced.copy() - style_unvoiced.pop('label', None) - style_unvoiced['alpha'] = style_unvoiced.get('alpha', 1.0) * 0.5 + style_unvoiced.pop("label", None) + style_unvoiced["alpha"] = style_unvoiced.get("alpha", 1.0) * 0.5 # We'll collect all times and frequencies first, then plot them voiced_times = [] @@ -668,7 +664,7 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): - '''Plot a quantized piano roll as intervals + """Plot a quantized piano roll as intervals Parameters ---------- @@ -687,66 +683,61 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): An axis handle on which to draw the intervals. If none is provided, a new set of axes is created. - kwargs + **kwargs Additional keyword arguments to :func:`labeled_intervals`. Returns ------- ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes - ''' - + """ if midi is None: if pitches is None: - raise ValueError('At least one of `midi` or `pitches` ' - 'must be provided.') + raise ValueError("At least one of `midi` or `pitches` " "must be provided.") midi = hz_to_midi(pitches) scale = np.arange(128) - ax = labeled_intervals(intervals, np.round(midi).astype(int), - label_set=scale, - tick=False, - ax=ax, - **kwargs) + ax = labeled_intervals( + intervals, + np.round(midi).astype(int), + label_set=scale, + tick=False, + ax=ax, + **kwargs + ) # Minor tick at each semitone ax.yaxis.set_minor_locator(MultipleLocator(1)) - ax.axis('auto') + ax.axis("auto") return ax def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): - '''Source-separation visualization + """Source-separation visualization Parameters ---------- sources : np.ndarray, shape=(nsrc, nsampl) A list of waveform buffers corresponding to each source - fs : number > 0 The sampling rate - labels : list of strings An optional list of descriptors corresponding to each source - alpha : float in [0, 1] Maximum alpha (opacity) of spectrogram values. - ax : matplotlib.pyplot.axes An axis handle on which to draw the spectrograms. If none is provided, a new set of axes is created. - - kwargs + **kwargs Additional keyword arguments to ``scipy.signal.spectrogram`` Returns ------- ax The axis handle for this plot - ''' - + """ # Get the axes handle ax, new_axes = __get_axes(ax=ax) @@ -754,9 +745,9 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): sources = np.atleast_2d(sources) if labels is None: - labels = ['Source {:d}'.format(_) for _ in range(len(sources))] + labels = ["Source {:d}".format(_) for _ in range(len(sources))] - kwargs.setdefault('scaling', 'spectrum') + kwargs.setdefault("scaling", "spectrum") # The cumulative spectrogram across sources # is used to establish the reference power @@ -778,40 +769,42 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): color_conv = ColorConverter() for i, spec in enumerate(specs): - # For each source, grab a new color from the cycler # Then construct a colormap that interpolates from # [transparent white -> new color] - color = next(ax._get_lines.prop_cycler)['color'] + color = next(ax._get_lines.prop_cycler)["color"] color = color_conv.to_rgba(color, alpha=alpha) - cmap = LinearSegmentedColormap.from_list(labels[i], - [(1.0, 1.0, 1.0, 0.0), - color]) - - ax.pcolormesh(times, freqs, spec, - cmap=cmap, - norm=LogNorm(vmin=ref_min, vmax=ref_max), - shading='gouraud', - label=labels[i]) + cmap = LinearSegmentedColormap.from_list( + labels[i], [(1.0, 1.0, 1.0, 0.0), color] + ) + + ax.pcolormesh( + times, + freqs, + spec, + cmap=cmap, + norm=LogNorm(vmin=ref_min, vmax=ref_max), + shading="gouraud", + label=labels[i], + ) # Attach a 0x0 rect to the axis with the corresponding label # This way, it will show up in the legend ax.add_patch(Rectangle((0, 0), 0, 0, color=color, label=labels[i])) if new_axes: - ax.axis('tight') + ax.axis("tight") return ax def __ticker_midi_note(x, pos): - '''A ticker function for midi notes. + """Format midi notes for ticker decoration. Inputs x are interpreted as midi numbers, and converted to [NOTE][OCTAVE]+[cents]. - ''' - - NOTES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] + """ + NOTES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] cents = float(np.mod(x, 1.0)) if cents >= 0.5: @@ -823,22 +816,21 @@ def __ticker_midi_note(x, pos): octave = int(x / 12) - 1 if cents == 0: - return '{:s}{:2d}'.format(NOTES[idx], octave) - return '{:s}{:2d}{:+02d}'.format(NOTES[idx], octave, int(cents * 100)) + return "{:s}{:2d}".format(NOTES[idx], octave) + return "{:s}{:2d}{:+02d}".format(NOTES[idx], octave, int(cents * 100)) def __ticker_midi_hz(x, pos): - '''A ticker function for midi pitches. + """Format midi pitches for ticker decoration. Inputs x are interpreted as midi numbers, and converted to Hz. - ''' - - return '{:g}'.format(midi_to_hz(x)) + """ + return "{:g}".format(midi_to_hz(x)) def ticker_notes(ax=None): - '''Set the y-axis of the given axes to MIDI notes + """Set the y-axis of the given axes to MIDI notes Parameters ---------- @@ -846,24 +838,24 @@ def ticker_notes(ax=None): The axes handle to apply the ticker. By default, uses the current axes handle. - ''' + """ ax, _ = __get_axes(ax=ax) ax.yaxis.set_major_formatter(FMT_MIDI_NOTE) # Get the tick labels and reset the vertical alignment for tick in ax.yaxis.get_ticklabels(): - tick.set_verticalalignment('baseline') + tick.set_verticalalignment("baseline") def ticker_pitch(ax=None): - '''Set the y-axis of the given axes to MIDI frequencies + """Set the y-axis of the given axes to MIDI frequencies Parameters ---------- ax : matplotlib.pyplot.axes The axes handle to apply the ticker. By default, uses the current axes handle. - ''' + """ ax, _ = __get_axes(ax=ax) ax.yaxis.set_major_formatter(FMT_MIDI_HZ) diff --git a/mir_eval/hierarchy.py b/mir_eval/hierarchy.py index c5015970..c5ad5be5 100644 --- a/mir_eval/hierarchy.py +++ b/mir_eval/hierarchy.py @@ -1,6 +1,6 @@ # CREATED:2015-09-16 14:46:47 by Brian McFee # -*- encoding: utf-8 -*- -'''Evaluation criteria for hierarchical structure analysis. +"""Evaluation criteria for hierarchical structure analysis. Hierarchical structure analysis seeks to annotate a track with a nested decomposition of the temporal elements of the piece, effectively providing @@ -39,7 +39,7 @@ Juan P. Bello. "Evaluating hierarchical structure in music annotations", Frontiers in Psychology, 2017. -''' +""" import collections import itertools @@ -53,7 +53,7 @@ def _round(t, frame_size): - '''Round a time-stamp to a specified resolution. + """Round a time-stamp to a specified resolution. Equivalent to ``t - np.mod(t, frame_size)``. @@ -68,7 +68,6 @@ def _round(t, frame_size): ---------- t : number or ndarray The time-stamp to round - frame_size : number > 0 The resolution to round to @@ -76,12 +75,12 @@ def _round(t, frame_size): ------- t_round : number The rounded time-stamp - ''' + """ return t - np.mod(t, float(frame_size)) def _hierarchy_bounds(intervals_hier): - '''Compute the covered time range of a hierarchical segmentation. + """Compute the covered time range of a hierarchical segmentation. Parameters ---------- @@ -94,14 +93,14 @@ def _hierarchy_bounds(intervals_hier): t_min : float t_max : float The minimum and maximum times spanned by the annotation - ''' + """ boundaries = list(itertools.chain(*list(itertools.chain(*intervals_hier)))) return min(boundaries), max(boundaries) def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None): - '''Align a hierarchical annotation to span a fixed start and end time. + """Align a hierarchical annotation to span a fixed start and end time. Parameters ---------- @@ -110,10 +109,8 @@ def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None): Hierarchical segment annotations, encoded as a list of list of intervals (int_hier) and list of list of strings (lab_hier) - t_min : None or number >= 0 The minimum time value for the segmentation - t_max : None or number >= t_min The maximum time value for the segmentation @@ -122,16 +119,22 @@ def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None): intervals_hier : list of list of intervals labels_hier : list of list of str `int_hier` `lab_hier` aligned to span `[t_min, t_max]`. - ''' - return [list(_) for _ in zip(*[util.adjust_intervals(np.asarray(ival), - labels=lab, - t_min=t_min, - t_max=t_max) - for ival, lab in zip(int_hier, lab_hier)])] + """ + return [ + list(_) + for _ in zip( + *[ + util.adjust_intervals( + np.asarray(ival), labels=lab, t_min=t_min, t_max=t_max + ) + for ival, lab in zip(int_hier, lab_hier) + ] + ) + ] def _lca(intervals_hier, frame_size): - '''Compute the (sparse) least-common-ancestor (LCA) matrix for a + """Compute the (sparse) least-common-ancestor (LCA) matrix for a hierarchical segmentation. For any pair of frames ``(s, t)``, the LCA is the deepest level in @@ -143,7 +146,6 @@ def _lca(intervals_hier, frame_size): intervals_hier : list of ndarray An ordered list of segment interval arrays. The list is assumed to be ordered by increasing specificity (depth). - frame_size : number The length of the sample frames (in seconds) @@ -152,23 +154,22 @@ def _lca(intervals_hier, frame_size): lca_matrix : scipy.sparse.csr_matrix A sparse matrix such that ``lca_matrix[i, j]`` contains the depth of the deepest segment containing frames ``i`` and ``j``. - ''' - + """ frame_size = float(frame_size) # Figure out how many frames we need n_start, n_end = _hierarchy_bounds(intervals_hier) - n = int((_round(n_end, frame_size) - - _round(n_start, frame_size)) / frame_size) + n = int((_round(n_end, frame_size) - _round(n_start, frame_size)) / frame_size) # Initialize the LCA matrix lca_matrix = scipy.sparse.lil_matrix((n, n), dtype=np.uint8) for level, intervals in enumerate(intervals_hier, 1): - for ival in (_round(np.asarray(intervals), - frame_size) / frame_size).astype(int): + for ival in (_round(np.asarray(intervals), frame_size) / frame_size).astype( + int + ): idx = slice(ival[0], ival[1]) lca_matrix[idx, idx] = level @@ -176,7 +177,7 @@ def _lca(intervals_hier, frame_size): def _meet(intervals_hier, labels_hier, frame_size): - '''Compute the (sparse) least-common-ancestor (LCA) matrix for a + """Compute the (sparse) least-common-ancestor (LCA) matrix for a hierarchical segmentation. For any pair of frames ``(s, t)``, the LCA is the deepest level in @@ -188,11 +189,9 @@ def _meet(intervals_hier, labels_hier, frame_size): intervals_hier : list of ndarray An ordered list of segment interval arrays. The list is assumed to be ordered by increasing specificity (depth). - labels_hier : list of list of str ``labels_hier[i]`` contains the segment labels for the ``i``th layer of the annotations - frame_size : number The length of the sample frames (in seconds) @@ -201,23 +200,19 @@ def _meet(intervals_hier, labels_hier, frame_size): meet_matrix : scipy.sparse.csr_matrix A sparse matrix such that ``meet_matrix[i, j]`` contains the depth of the deepest segment label containing both ``i`` and ``j``. - ''' - + """ frame_size = float(frame_size) # Figure out how many frames we need n_start, n_end = _hierarchy_bounds(intervals_hier) - n = int((_round(n_end, frame_size) - - _round(n_start, frame_size)) / frame_size) + n = int((_round(n_end, frame_size) - _round(n_start, frame_size)) / frame_size) # Initialize the meet matrix meet_matrix = scipy.sparse.lil_matrix((n, n), dtype=np.uint8) - for level, (intervals, labels) in enumerate(zip(intervals_hier, - labels_hier), 1): - + for level, (intervals, labels) in enumerate(zip(intervals_hier, labels_hier), 1): # Encode the labels at this level lab_enc = util.index_labels(labels)[0] @@ -228,7 +223,7 @@ def _meet(intervals_hier, labels_hier, frame_size): int_frames = (_round(intervals, frame_size) / frame_size).astype(int) # For each intervals i, j where labels agree, update the meet matrix - for (seg_i, seg_j) in zip(*np.where(int_agree)): + for seg_i, seg_j in zip(*np.where(int_agree)): idx_i = slice(*list(int_frames[seg_i])) idx_j = slice(*list(int_frames[seg_j])) meet_matrix[idx_i, idx_j] = level @@ -239,7 +234,7 @@ def _meet(intervals_hier, labels_hier, frame_size): def _gauc(ref_lca, est_lca, transitive, window): - '''Generalized area under the curve (GAUC) + """Generalized area under the curve (GAUC) This function computes the normalized recall score for correctly ordering triples ``(q, i, j)`` where frames ``(q, i)`` are closer than @@ -248,6 +243,7 @@ def _gauc(ref_lca, est_lca, transitive, window): Parameters ---------- ref_lca : scipy.sparse + est_lca : scipy.sparse The least common ancestor matrices for the reference and estimated annotations @@ -273,12 +269,13 @@ def _gauc(ref_lca, est_lca, transitive, window): ------ ValueError If ``ref_lca`` and ``est_lca`` have different shapes - ''' + """ # Make sure we have the right number of frames if ref_lca.shape != est_lca.shape: - raise ValueError('Estimated and reference hierarchies ' - 'must have the same shape.') + raise ValueError( + "Estimated and reference hierarchies " "must have the same shape." + ) # How many frames? n = ref_lca.shape[0] @@ -294,7 +291,6 @@ def _gauc(ref_lca, est_lca, transitive, window): num_frames = 0 for query in range(n): - # Find all pairs i,j such that ref_lca[q, i] > ref_lca[q, j] results = slice(max(0, query - window), min(n, query + window)) @@ -311,11 +307,12 @@ def _gauc(ref_lca, est_lca, transitive, window): # (this also holds when the slice goes off the end of the array.) idx = min(query, window) - ref_score = np.concatenate((ref_score[:idx], ref_score[idx+1:])) - est_score = np.concatenate((est_score[:idx], est_score[idx+1:])) + ref_score = np.concatenate((ref_score[:idx], ref_score[idx + 1 :])) + est_score = np.concatenate((est_score[:idx], est_score[idx + 1 :])) - inversions, normalizer = _compare_frame_rankings(ref_score, est_score, - transitive=transitive) + inversions, normalizer = _compare_frame_rankings( + ref_score, est_score, transitive=transitive + ) if normalizer: score += 1.0 - inversions / float(normalizer) @@ -332,7 +329,7 @@ def _gauc(ref_lca, est_lca, transitive, window): def _count_inversions(a, b): - '''Count the number of inversions in two numpy arrays: + """Count the number of inversions in two numpy arrays: # points i, j where a[i] >= b[j] @@ -348,8 +345,7 @@ def _count_inversions(a, b): ------- inversions : int The number of detected inversions - ''' - + """ a, a_counts = np.unique(a, return_counts=True) b, b_counts = np.unique(b, return_counts=True) @@ -368,7 +364,7 @@ def _count_inversions(a, b): def _compare_frame_rankings(ref, est, transitive=False): - '''Compute the number of ranking disagreements in two lists. + """Compute the number of ranking disagreements in two lists. Parameters ---------- @@ -376,7 +372,6 @@ def _compare_frame_rankings(ref, est, transitive=False): est : np.ndarray, shape=(n,) Reference and estimate ranked lists. `ref[i]` is the relevance score for point `i`. - transitive : bool If true, all pairs of reference levels are compared. If false, only adjacent pairs of reference levels are compared. @@ -386,21 +381,19 @@ def _compare_frame_rankings(ref, est, transitive=False): inversions : int The number of pairs of indices `i, j` where `ref[i] < ref[j]` but `est[i] >= est[j]`. - normalizer : float The total number of pairs (i, j) under consideration. If transitive=True, then this is |{(i,j) : ref[i] < ref[j]}| If transitive=False, then this is |{i,j) : ref[i] +1 = ref[j]}| - ''' - + """ idx = np.argsort(ref) ref_sorted = ref[idx] est_sorted = est[idx] # Find the break-points in ref_sorted - levels, positions, counts = np.unique(ref_sorted, - return_index=True, - return_counts=True) + levels, positions, counts = np.unique( + ref_sorted, return_index=True, return_counts=True + ) positions = list(positions) positions.append(len(ref_sorted)) @@ -408,8 +401,7 @@ def _compare_frame_rankings(ref, est, transitive=False): index = collections.defaultdict(lambda: slice(0)) ref_map = collections.defaultdict(lambda: 0) - for level, cnt, start, end in zip(levels, counts, - positions[:-1], positions[1:]): + for level, cnt, start, end in zip(levels, counts, positions[:-1], positions[1:]): index[level] = slice(start, end) ref_map[level] = cnt @@ -418,7 +410,7 @@ def _compare_frame_rankings(ref, est, transitive=False): if transitive: level_pairs = itertools.combinations(levels, 2) else: - level_pairs = [(i, i+1) for i in levels] + level_pairs = [(i, i + 1) for i in levels] level_pairs, lcounter = itertools.tee(level_pairs) @@ -430,14 +422,15 @@ def _compare_frame_rankings(ref, est, transitive=False): inversions = 0 for level_1, level_2 in level_pairs: - inversions += _count_inversions(est_sorted[index[level_1]], - est_sorted[index[level_2]]) + inversions += _count_inversions( + est_sorted[index[level_1]], est_sorted[index[level_2]] + ) return inversions, float(normalizer) def validate_hier_intervals(intervals_hier): - '''Validate a hierarchical segment annotation. + """Validate a hierarchical segment annotation. Parameters ---------- @@ -450,8 +443,7 @@ def validate_hier_intervals(intervals_hier): segmentation. If any segmentation does not start at 0. - ''' - + """ # Synthesize a label array for the top layer. label_top = util.generate_labels(intervals_hier[0]) @@ -460,21 +452,27 @@ def validate_hier_intervals(intervals_hier): for level, intervals in enumerate(intervals_hier[1:], 1): # Make sure this level is consistent with the root label_current = util.generate_labels(intervals) - validate_structure(intervals_hier[0], label_top, - intervals, label_current) + validate_structure(intervals_hier[0], label_top, intervals, label_current) # Make sure all previous boundaries are accounted for new_bounds = set(util.intervals_to_boundaries(intervals)) if boundaries - new_bounds: - warnings.warn('Segment hierarchy is inconsistent ' - 'at level {:d}'.format(level)) + warnings.warn( + "Segment hierarchy is inconsistent " "at level {:d}".format(level) + ) boundaries |= new_bounds -def tmeasure(reference_intervals_hier, estimated_intervals_hier, - transitive=False, window=15.0, frame_size=0.1, beta=1.0): - '''Computes the tree measures for hierarchical segment annotations. +def tmeasure( + reference_intervals_hier, + estimated_intervals_hier, + transitive=False, + window=15.0, + frame_size=0.1, + beta=1.0, +): + """Compute the tree measures for hierarchical segment annotations. Parameters ---------- @@ -483,21 +481,16 @@ def tmeasure(reference_intervals_hier, estimated_intervals_hier, (in seconds) for the ``i`` th layer of the annotations. Layers are ordered from top to bottom, so that the last list of intervals should be the most specific. - estimated_intervals_hier : list of ndarray Like ``reference_intervals_hier`` but for the estimated annotation - transitive : bool whether to compute the t-measures using transitivity or not. - window : float > 0 size of the window (in seconds). For each query frame q, result frames are only counted within q +- window. - frame_size : float > 0 length (in seconds) of frames. The frame size cannot be longer than the window. - beta : float > 0 beta parameter for the F-measure. @@ -505,10 +498,8 @@ def tmeasure(reference_intervals_hier, estimated_intervals_hier, ------- t_precision : number [0, 1] T-measure Precision - t_recall : number [0, 1] T-measure Recall - t_measure : number [0, 1] F-beta measure for ``(t_precision, t_recall)`` @@ -520,19 +511,21 @@ def tmeasure(reference_intervals_hier, estimated_intervals_hier, If the input hierarchies have different time durations If ``frame_size > window`` or ``frame_size <= 0`` - ''' - + """ # Compute the number of frames in the window if frame_size <= 0: - raise ValueError('frame_size ({:.2f}) must be a positive ' - 'number.'.format(frame_size)) + raise ValueError( + "frame_size ({:.2f}) must be a positive " "number.".format(frame_size) + ) if window is None: window_frames = None else: if frame_size > window: - raise ValueError('frame_size ({:.2f}) cannot exceed ' - 'window ({:.2f})'.format(frame_size, window)) + raise ValueError( + "frame_size ({:.2f}) cannot exceed " + "window ({:.2f})".format(frame_size, window) + ) window_frames = int(_round(window, frame_size) / frame_size) @@ -553,10 +546,15 @@ def tmeasure(reference_intervals_hier, estimated_intervals_hier, return t_precision, t_recall, t_measure -def lmeasure(reference_intervals_hier, reference_labels_hier, - estimated_intervals_hier, estimated_labels_hier, - frame_size=0.1, beta=1.0): - '''Computes the tree measures for hierarchical segment annotations. +def lmeasure( + reference_intervals_hier, + reference_labels_hier, + estimated_intervals_hier, + estimated_labels_hier, + frame_size=0.1, + beta=1.0, +): + """Compute the tree measures for hierarchical segment annotations. Parameters ---------- @@ -565,20 +563,16 @@ def lmeasure(reference_intervals_hier, reference_labels_hier, (in seconds) for the ``i`` th layer of the annotations. Layers are ordered from top to bottom, so that the last list of intervals should be the most specific. - reference_labels_hier : list of list of str ``reference_labels_hier[i]`` contains the segment labels for the ``i``th layer of the annotations - estimated_intervals_hier : list of ndarray estimated_labels_hier : list of ndarray Like ``reference_intervals_hier`` and ``reference_labels_hier`` but for the estimated annotation - frame_size : float > 0 length (in seconds) of frames. The frame size cannot be longer than the window. - beta : float > 0 beta parameter for the F-measure. @@ -586,10 +580,8 @@ def lmeasure(reference_intervals_hier, reference_labels_hier, ------- l_precision : number [0, 1] L-measure Precision - l_recall : number [0, 1] L-measure Recall - l_measure : number [0, 1] F-beta measure for ``(l_precision, l_recall)`` @@ -601,22 +593,20 @@ def lmeasure(reference_intervals_hier, reference_labels_hier, If the input hierarchies have different time durations If ``frame_size > window`` or ``frame_size <= 0`` - ''' - + """ # Compute the number of frames in the window if frame_size <= 0: - raise ValueError('frame_size ({:.2f}) must be a positive ' - 'number.'.format(frame_size)) + raise ValueError( + "frame_size ({:.2f}) must be a positive " "number.".format(frame_size) + ) # Validate the hierarchical segmentations validate_hier_intervals(reference_intervals_hier) validate_hier_intervals(estimated_intervals_hier) # Build the least common ancestor matrices - ref_meet = _meet(reference_intervals_hier, reference_labels_hier, - frame_size) - est_meet = _meet(estimated_intervals_hier, estimated_labels_hier, - frame_size) + ref_meet = _meet(reference_intervals_hier, reference_labels_hier, frame_size) + est_meet = _meet(estimated_intervals_hier, estimated_labels_hier, frame_size) # Compute precision and recall l_recall = _gauc(ref_meet, est_meet, True, None) @@ -627,9 +617,10 @@ def lmeasure(reference_intervals_hier, reference_labels_hier, return l_precision, l_recall, l_measure -def evaluate(ref_intervals_hier, ref_labels_hier, - est_intervals_hier, est_labels_hier, **kwargs): - '''Compute all hierarchical structure metrics for the given reference and +def evaluate( + ref_intervals_hier, ref_labels_hier, est_intervals_hier, est_labels_hier, **kwargs +): + r"""Compute all hierarchical structure metrics for the given reference and estimated annotations. Examples @@ -676,7 +667,6 @@ def evaluate(ref_intervals_hier, ref_labels_hier, 'T-Recall full': 0.6523334654992341, 'T-Recall reduced': 0.60799919710921635} - Parameters ---------- ref_intervals_hier : list of list-like @@ -687,13 +677,12 @@ def evaluate(ref_intervals_hier, ref_labels_hier, of segmentations. Each segmentation itself is a list (or list-like) of intervals (\*_intervals_hier) and a list of lists of labels (\*_labels_hier). - - kwargs + **kwargs additional keyword arguments to the evaluation metrics. Returns ------- - scores : OrderedDict + scores : OrderedDict Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. @@ -704,48 +693,47 @@ def evaluate(ref_intervals_hier, ref_labels_hier, ------ ValueError Thrown when the provided annotations are not valid. - ''' - + """ # First, find the maximum length of the reference _, t_end = _hierarchy_bounds(ref_intervals_hier) # Pre-process the intervals to match the range of the reference, # and start at 0 - ref_intervals_hier, ref_labels_hier = _align_intervals(ref_intervals_hier, - ref_labels_hier, - t_min=0.0, - t_max=None) + ref_intervals_hier, ref_labels_hier = _align_intervals( + ref_intervals_hier, ref_labels_hier, t_min=0.0, t_max=None + ) - est_intervals_hier, est_labels_hier = _align_intervals(est_intervals_hier, - est_labels_hier, - t_min=0.0, - t_max=t_end) + est_intervals_hier, est_labels_hier = _align_intervals( + est_intervals_hier, est_labels_hier, t_min=0.0, t_max=t_end + ) scores = collections.OrderedDict() # Force the transitivity setting - kwargs['transitive'] = False - (scores['T-Precision reduced'], - scores['T-Recall reduced'], - scores['T-Measure reduced']) = util.filter_kwargs(tmeasure, - ref_intervals_hier, - est_intervals_hier, - **kwargs) - - kwargs['transitive'] = True - (scores['T-Precision full'], - scores['T-Recall full'], - scores['T-Measure full']) = util.filter_kwargs(tmeasure, - ref_intervals_hier, - est_intervals_hier, - **kwargs) - - (scores['L-Precision'], - scores['L-Recall'], - scores['L-Measure']) = util.filter_kwargs(lmeasure, - ref_intervals_hier, - ref_labels_hier, - est_intervals_hier, - est_labels_hier, - **kwargs) + kwargs["transitive"] = False + ( + scores["T-Precision reduced"], + scores["T-Recall reduced"], + scores["T-Measure reduced"], + ) = util.filter_kwargs(tmeasure, ref_intervals_hier, est_intervals_hier, **kwargs) + + kwargs["transitive"] = True + ( + scores["T-Precision full"], + scores["T-Recall full"], + scores["T-Measure full"], + ) = util.filter_kwargs(tmeasure, ref_intervals_hier, est_intervals_hier, **kwargs) + + ( + scores["L-Precision"], + scores["L-Recall"], + scores["L-Measure"], + ) = util.filter_kwargs( + lmeasure, + ref_intervals_hier, + ref_labels_hier, + est_intervals_hier, + est_labels_hier, + **kwargs + ) return scores diff --git a/mir_eval/io.py b/mir_eval/io.py index 072cb3d1..37ed9efc 100644 --- a/mir_eval/io.py +++ b/mir_eval/io.py @@ -1,6 +1,4 @@ -""" -Functions for loading in annotations from files in different formats. -""" +"""Functions for loading annotations from files in different formats.""" import contextlib import numpy as np @@ -15,27 +13,26 @@ @contextlib.contextmanager def _open(file_or_str, **kwargs): - '''Either open a file handle, or use an existing file-like object. + """Either open a file handle, or use an existing file-like object. This will behave as the `open` function if `file_or_str` is a string. If `file_or_str` has the `read` attribute, it will return `file_or_str`. Otherwise, an `IOError` is raised. - ''' - if hasattr(file_or_str, 'read'): + """ + if hasattr(file_or_str, "read"): yield file_or_str elif isinstance(file_or_str, str): with open(file_or_str, **kwargs) as file_desc: yield file_desc else: - raise IOError('Invalid file-or-str object: {}'.format(file_or_str)) + raise IOError("Invalid file-or-str object: {}".format(file_or_str)) -def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): - r"""Utility function for loading in data from an annotation file where columns - are delimited. The number of columns is inferred from the length of - the provided converters list. +def load_delimited(filename, converters, delimiter=r"\s+", comment="#"): + r"""Load data from an annotation file where columns are delimited. + The number of columns is inferred from the length of the provided converters list. Examples -------- @@ -48,12 +45,15 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + converters : list of functions Each entry in column ``n`` of the file will be cast by the function ``converters[n]``. + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -65,7 +65,6 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): columns : tuple of lists Each list in this tuple corresponds to values in one of the columns in the file. - """ # Initialize list of empty lists n_columns = len(converters) @@ -78,7 +77,7 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): if comment is None: commenter = None else: - commenter = re.compile('^{}'.format(comment)) + commenter = re.compile("^{}".format(comment)) # Note: we do io manually here for two reasons. # 1. The csv module has difficulties with unicode, which may lead @@ -86,7 +85,7 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): # # 2. numpy's text loader does not handle non-numeric data # - with _open(filename, mode='r') as input_file: + with _open(filename, mode="r") as input_file: for row, line in enumerate(input_file, 1): # Skip commented lines if comment is not None and commenter.match(line): @@ -97,19 +96,22 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): # Throw a helpful error if we got an unexpected # of columns if n_columns != len(data): - raise ValueError('Expected {} columns, got {} at ' - '{}:{:d}:\n\t{}'.format(n_columns, len(data), - filename, row, line)) + raise ValueError( + "Expected {} columns, got {} at " + "{}:{:d}:\n\t{}".format(n_columns, len(data), filename, row, line) + ) for value, column, converter in zip(data, columns, converters): # Try converting the value, throw a helpful error on failure try: converted_value = converter(value) except: - raise ValueError("Couldn't convert value {} using {} " - "found at {}:{:d}:\n\t{}".format( - value, converter.__name__, filename, - row, line)) + raise ValueError( + "Couldn't convert value {} using {} " + "found at {}:{:d}:\n\t{}".format( + value, converter.__name__, filename, row, line + ) + ) column.append(converted_value) # Sane output @@ -119,7 +121,7 @@ def load_delimited(filename, converters, delimiter=r'\s+', comment='#'): return columns -def load_events(filename, delimiter=r'\s+', comment='#'): +def load_events(filename, delimiter=r"\s+", comment="#"): r"""Import time-stamp events from an annotation file. The file should consist of a single column of numeric values corresponding to the event times. This is primarily useful for processing events which lack duration, @@ -129,9 +131,11 @@ def load_events(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -145,8 +149,7 @@ def load_events(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - events = load_delimited(filename, [float], - delimiter=delimiter, comment=comment) + events = load_delimited(filename, [float], delimiter=delimiter, comment=comment) events = np.array(events) # Validate them, but throw a warning in place of an error try: @@ -157,7 +160,7 @@ def load_events(filename, delimiter=r'\s+', comment='#'): return events -def load_labeled_events(filename, delimiter=r'\s+', comment='#'): +def load_labeled_events(filename, delimiter=r"\s+", comment="#"): r"""Import labeled time-stamp events from an annotation file. The file should consist of two columns; the first having numeric values corresponding to the event times and the second having string labels for each event. This @@ -168,9 +171,11 @@ def load_labeled_events(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -186,9 +191,9 @@ def load_labeled_events(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - events, labels = load_delimited(filename, [float, str], - delimiter=delimiter, - comment=comment) + events, labels = load_delimited( + filename, [float, str], delimiter=delimiter, comment=comment + ) events = np.array(events) # Validate them, but throw a warning in place of an error try: @@ -199,7 +204,7 @@ def load_labeled_events(filename, delimiter=r'\s+', comment='#'): return events, labels -def load_intervals(filename, delimiter=r'\s+', comment='#'): +def load_intervals(filename, delimiter=r"\s+", comment="#"): r"""Import intervals from an annotation file. The file should consist of two columns of numeric values corresponding to start and end time of each interval. This is primarily useful for processing events which span a @@ -209,9 +214,11 @@ def load_intervals(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -225,9 +232,9 @@ def load_intervals(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - starts, ends = load_delimited(filename, [float, float], - delimiter=delimiter, - comment=comment) + starts, ends = load_delimited( + filename, [float, float], delimiter=delimiter, comment=comment + ) # Stack into an interval matrix intervals = np.array([starts, ends]).T # Validate them, but throw a warning in place of an error @@ -239,7 +246,7 @@ def load_intervals(filename, delimiter=r'\s+', comment='#'): return intervals -def load_labeled_intervals(filename, delimiter=r'\s+', comment='#'): +def load_labeled_intervals(filename, delimiter=r"\s+", comment="#"): r"""Import labeled intervals from an annotation file. The file should consist of three columns: Two consisting of numeric values corresponding to start and end time of each interval and a third corresponding to the label of @@ -250,9 +257,11 @@ def load_labeled_intervals(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -268,9 +277,9 @@ def load_labeled_intervals(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - starts, ends, labels = load_delimited(filename, [float, float, str], - delimiter=delimiter, - comment=comment) + starts, ends, labels = load_delimited( + filename, [float, float, str], delimiter=delimiter, comment=comment + ) # Stack into an interval matrix intervals = np.array([starts, ends]).T # Validate them, but throw a warning in place of an error @@ -282,7 +291,7 @@ def load_labeled_intervals(filename, delimiter=r'\s+', comment='#'): return intervals, labels -def load_time_series(filename, delimiter=r'\s+', comment='#'): +def load_time_series(filename, delimiter=r"\s+", comment="#"): r"""Import a time series from an annotation file. The file should consist of two columns of numeric values corresponding to the time and value of each sample of the time series. @@ -291,9 +300,11 @@ def load_time_series(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -309,9 +320,9 @@ def load_time_series(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - times, values = load_delimited(filename, [float, float], - delimiter=delimiter, - comment=comment) + times, values = load_delimited( + filename, [float, float], delimiter=delimiter, comment=comment + ) times = np.array(times) values = np.array(values) @@ -319,7 +330,7 @@ def load_time_series(filename, delimiter=r'\s+', comment='#'): def load_patterns(filename): - """Loads the patters contained in the filename and puts them into a list + """Load the patterns contained in the filename and puts them into a list of patterns, each pattern being a list of occurrence, and each occurrence being a list of (onset, midi) pairs. @@ -360,16 +371,14 @@ def load_patterns(filename): pattern2 = [occ1, occ2] pattern_list = [pattern1, pattern2] - """ - # List with all the patterns pattern_list = [] # Current pattern, which will contain all occs pattern = [] # Current occurrence, containing (onset, midi) occurrence = [] - with _open(filename, mode='r') as input_file: + with _open(filename, mode="r") as input_file: for line in input_file.readlines(): if "pattern" in line: if occurrence != []: @@ -398,7 +407,7 @@ def load_patterns(filename): def load_wav(path, mono=True): - """Loads a .wav file as a numpy array using ``scipy.io.wavfile``. + """Load a .wav file as a numpy array using ``scipy.io.wavfile``. Parameters ---------- @@ -414,27 +423,24 @@ def load_wav(path, mono=True): Array of audio samples, normalized to the range [-1., 1.] fs : int Sampling rate of the audio data - """ - fs, audio_data = scipy.io.wavfile.read(path) # Make float in range [-1, 1] - if audio_data.dtype == 'int8': - audio_data = audio_data/float(2**8) - elif audio_data.dtype == 'int16': - audio_data = audio_data/float(2**16) - elif audio_data.dtype == 'int32': - audio_data = audio_data/float(2**24) + if audio_data.dtype == "int8": + audio_data = audio_data / float(2**8) + elif audio_data.dtype == "int16": + audio_data = audio_data / float(2**16) + elif audio_data.dtype == "int32": + audio_data = audio_data / float(2**24) else: - raise ValueError('Got unexpected .wav data type ' - '{}'.format(audio_data.dtype)) + raise ValueError("Got unexpected .wav data type " "{}".format(audio_data.dtype)) # Optionally convert to mono if mono and audio_data.ndim != 1: audio_data = audio_data.mean(axis=1) return audio_data, fs -def load_valued_intervals(filename, delimiter=r'\s+', comment='#'): +def load_valued_intervals(filename, delimiter=r"\s+", comment="#"): r"""Import valued intervals from an annotation file. The file should consist of three columns: Two consisting of numeric values corresponding to start and end time of each interval and a third, also of numeric values, @@ -446,9 +452,11 @@ def load_valued_intervals(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -464,9 +472,9 @@ def load_valued_intervals(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load in the events - starts, ends, values = load_delimited(filename, [float, float, float], - delimiter=delimiter, - comment=comment) + starts, ends, values = load_delimited( + filename, [float, float, float], delimiter=delimiter, comment=comment + ) # Stack into an interval matrix intervals = np.array([starts, ends]).T # Validate them, but throw a warning in place of an error @@ -481,7 +489,7 @@ def load_valued_intervals(filename, delimiter=r'\s+', comment='#'): return intervals, values -def load_key(filename, delimiter=r'\s+', comment='#'): +def load_key(filename, delimiter=r"\s+", comment="#"): r"""Load key labels from an annotation file. The file should consist of two string columns: One denoting the key scale degree (semitone), and the other denoting the mode (major or minor). The file @@ -491,9 +499,11 @@ def load_key(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -507,14 +517,14 @@ def load_key(filename, delimiter=r'\s+', comment='#'): """ # Use our universal function to load the key and mode strings - scale, mode = load_delimited(filename, [str, str], - delimiter=delimiter, - comment=comment) + scale, mode = load_delimited( + filename, [str, str], delimiter=delimiter, comment=comment + ) if len(scale) != 1: - raise ValueError('Key file should contain only one line.') + raise ValueError("Key file should contain only one line.") scale, mode = scale[0], mode[0] # Join with a space - key_string = '{} {}'.format(scale, mode) + key_string = "{} {}".format(scale, mode) # Validate them, but throw a warning in place of an error try: key.validate_key(key_string) @@ -524,7 +534,7 @@ def load_key(filename, delimiter=r'\s+', comment='#'): return key_string -def load_tempo(filename, delimiter=r'\s+', comment='#'): +def load_tempo(filename, delimiter=r"\s+", comment="#"): r"""Load tempo estimates from an annotation file in MIREX format. The file should consist of three numeric columns: the first two correspond to tempo estimates (in beats-per-minute), and the third @@ -535,9 +545,11 @@ def load_tempo(filename, delimiter=r'\s+', comment='#'): ---------- filename : str Path to the annotation file + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -548,20 +560,19 @@ def load_tempo(filename, delimiter=r'\s+', comment='#'): ------- tempi : np.ndarray, non-negative The two tempo estimates - weight : float [0, 1] The relative importance of ``tempi[0]`` compared to ``tempi[1]`` """ # Use our universal function to load the key and mode strings - t1, t2, weight = load_delimited(filename, [float, float, float], - delimiter=delimiter, - comment=comment) + t1, t2, weight = load_delimited( + filename, [float, float, float], delimiter=delimiter, comment=comment + ) weight = weight[0] tempi = np.concatenate([t1, t2]) if len(t1) != 1: - raise ValueError('Tempo file should contain only one line.') + raise ValueError("Tempo file should contain only one line.") # Validate them, but throw a warning in place of an error try: @@ -570,17 +581,20 @@ def load_tempo(filename, delimiter=r'\s+', comment='#'): warnings.warn(error.args[0]) if not 0 <= weight <= 1: - raise ValueError('Invalid weight: {}'.format(weight)) + raise ValueError("Invalid weight: {}".format(weight)) return tempi, weight -def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+', - header=False, comment='#'): - r"""Utility function for loading in data from a delimited time series - annotation file with a variable number of columns. - Assumes that column 0 contains time stamps and columns 1 through n contain - values. n may be variable from time stamp to time stamp. +def load_ragged_time_series( + filename, dtype=float, delimiter=r"\s+", header=False, comment="#" +): + r"""Load data from a delimited time series annotation file with + a variable number of columns. + + This function assumes that column 0 contains time stamps and + columns 1 through n contain values. + n may be variable from time stamp to time stamp. Examples -------- @@ -595,14 +609,18 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+', ---------- filename : str Path to the annotation file + dtype : function Data type to apply to values columns. + delimiter : str Separator regular expression. By default, lines will be split by any amount of whitespace. + header : bool Indicates whether a header row is present or not. By default, assumes no header is present. + comment : str or None Comment regular expression. Any lines beginning with this string or pattern will be ignored. @@ -628,13 +646,13 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+', if comment is None: commenter = None else: - commenter = re.compile('^{}'.format(comment)) + commenter = re.compile("^{}".format(comment)) if header: start_row = 1 else: start_row = 0 - with _open(filename, mode='r') as input_file: + with _open(filename, mode="r") as input_file: for row, line in enumerate(input_file, start_row): # If this is a comment line, skip it if comment is not None and commenter.match(line): @@ -645,10 +663,12 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+', try: converted_time = float(data[0]) except (TypeError, ValueError) as exe: - raise ValueError("Couldn't convert value {} using {} " - "found at {}:{:d}:\n\t{}".format( - data[0], float.__name__, - filename, row, line)) from exe + raise ValueError( + "Couldn't convert value {} using {} " + "found at {}:{:d}:\n\t{}".format( + data[0], float.__name__, filename, row, line + ) + ) from exe times.append(converted_time) # cast values to a numpy array. time stamps with no values are cast @@ -656,10 +676,12 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+', try: converted_value = np.array(data[1:], dtype=dtype) except (TypeError, ValueError) as exe: - raise ValueError("Couldn't convert value {} using {} " - "found at {}:{:d}:\n\t{}".format( - data[1:], dtype.__name__, - filename, row, line)) from exe + raise ValueError( + "Couldn't convert value {} using {} " + "found at {}:{:d}:\n\t{}".format( + data[1:], dtype.__name__, filename, row, line + ) + ) from exe values.append(converted_value) return np.array(times), values diff --git a/mir_eval/key.py b/mir_eval/key.py index f72e8ad3..c0302385 100644 --- a/mir_eval/key.py +++ b/mir_eval/key.py @@ -1,4 +1,4 @@ -''' +""" Key Detection involves determining the underlying key (distribution of notes and note transitions) in a piece of music. Key detection algorithms are evaluated by comparing their estimated key to a ground-truth reference key and @@ -16,49 +16,66 @@ ------- * :func:`mir_eval.key.weighted_score`: Heuristic scoring of the relation of two keys. -''' +""" import collections from . import util -KEY_TO_SEMITONE = {'c': 0, 'c#': 1, 'db': 1, 'd': 2, 'd#': 3, 'eb': 3, 'e': 4, - 'f': 5, 'f#': 6, 'gb': 6, 'g': 7, 'g#': 8, 'ab': 8, 'a': 9, - 'a#': 10, 'bb': 10, 'b': 11, 'x': None} +KEY_TO_SEMITONE = { + "c": 0, + "c#": 1, + "db": 1, + "d": 2, + "d#": 3, + "eb": 3, + "e": 4, + "f": 5, + "f#": 6, + "gb": 6, + "g": 7, + "g#": 8, + "ab": 8, + "a": 9, + "a#": 10, + "bb": 10, + "b": 11, + "x": None, +} def validate_key(key): - """Checks that a key is well-formatted, e.g. in the form ``'C# major'``. - The Key can be 'X' if it is not possible to categorize the Key and mode - can be 'other' if it can't be categorized as major or minor. + """Check that a key is well-formatted, e.g. in the form ``'C# major'``. + The Key can be 'X' if it is not possible to categorize the Key and mode + can be 'other' if it can't be categorized as major or minor. Parameters ---------- key : str Key to verify """ - if len(key.split()) != 2 \ - and not (len(key.split()) and key.lower() == 'x'): - raise ValueError("'{}' is not in the form '(key) (mode)' " - "or 'X'".format(key)) - if key.lower() != 'x': + if len(key.split()) != 2 and not (len(key.split()) and key.lower() == "x"): + raise ValueError("'{}' is not in the form '(key) (mode)' " "or 'X'".format(key)) + if key.lower() != "x": key, mode = key.split() - if key.lower() == 'x': + if key.lower() == "x": raise ValueError( "Mode {} is invalid; 'X' (Uncategorized) " - "doesn't have mode".format(mode)) + "doesn't have mode".format(mode) + ) if key.lower() not in KEY_TO_SEMITONE: raise ValueError( "Key {} is invalid; should be e.g. D or C# or Eb or " - "X (Uncategorized)".format(key)) - if mode not in ['major', 'minor', 'other']: + "X (Uncategorized)".format(key) + ) + if mode not in ["major", "minor", "other"]: raise ValueError( - "Mode '{}' is invalid; must be 'major', 'minor' or 'other'" - .format(mode)) + "Mode '{}' is invalid; must be 'major', 'minor' or 'other'".format(mode) + ) def validate(reference_key, estimated_key): - """Checks that the input annotations to a metric are valid key strings and + """Check that the input annotations to a metric are valid key strings and throws helpful errors if not. Parameters @@ -73,7 +90,7 @@ def validate(reference_key, estimated_key): def split_key_string(key): - """Splits a key string (of the form, e.g. ``'C# major'``), into a tuple of + """Split a key string (of the form, e.g. ``'C# major'``), into a tuple of ``(key, mode)`` where ``key`` is is an integer representing the semitone distance from C. @@ -89,7 +106,7 @@ def split_key_string(key): mode : str String representing the mode. """ - if key.lower() != 'x': + if key.lower() != "x": key, mode = key.split() else: mode = None @@ -97,7 +114,7 @@ def split_key_string(key): def weighted_score(reference_key, estimated_key): - """Computes a heuristic score which is weighted according to the + """Compute a heuristic score which is weighted according to the relationship of the reference and estimated key, as follows: +------------------------------------------------------+-------+ @@ -137,28 +154,31 @@ def weighted_score(reference_key, estimated_key): estimated_key, estimated_mode = split_key_string(estimated_key) # If keys are the same, return 1. if reference_key == estimated_key and reference_mode == estimated_mode: - return 1. + return 1.0 # If reference or estimated key are x and they are not the same key # then the result is 'Other'. if reference_key is None or estimated_key is None: - return 0. + return 0.0 # If keys are the same mode and a perfect fifth (differ by 7 semitones) - if (estimated_mode == reference_mode and - (estimated_key - reference_key) % 12 == 7): + if estimated_mode == reference_mode and (estimated_key - reference_key) % 12 == 7: return 0.5 # Estimated key is relative minor of reference key (9 semitones) - if (estimated_mode != reference_mode == 'major' and - (estimated_key - reference_key) % 12 == 9): + if ( + estimated_mode != reference_mode == "major" + and (estimated_key - reference_key) % 12 == 9 + ): return 0.3 # Estimated key is relative major of reference key (3 semitones) - if (estimated_mode != reference_mode == 'minor' and - (estimated_key - reference_key) % 12 == 3): + if ( + estimated_mode != reference_mode == "minor" + and (estimated_key - reference_key) % 12 == 3 + ): return 0.3 # If keys are in different modes and parallel (same key name) if estimated_mode != reference_mode and reference_key == estimated_key: return 0.2 # Otherwise return 0 - return 0. + return 0.0 def evaluate(reference_key, estimated_key, **kwargs): @@ -172,13 +192,11 @@ def evaluate(reference_key, estimated_key, **kwargs): Parameters ---------- - ref_key : str + reference_key : str Reference key string. - - ref_key : str + estimated_key : str Estimated key string. - - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -191,7 +209,8 @@ def evaluate(reference_key, estimated_key, **kwargs): # Compute all metrics scores = collections.OrderedDict() - scores['Weighted Score'] = util.filter_kwargs( - weighted_score, reference_key, estimated_key) + scores["Weighted Score"] = util.filter_kwargs( + weighted_score, reference_key, estimated_key + ) return scores diff --git a/mir_eval/melody.py b/mir_eval/melody.py index f548d8db..ab4d6e42 100644 --- a/mir_eval/melody.py +++ b/mir_eval/melody.py @@ -1,5 +1,5 @@ # CREATED:2014-03-07 by Justin Salamon -''' +""" Melody extraction algorithms aim to produce a sequence of frequency values corresponding to the pitch of the dominant melody from a musical recording. For evaluation, an estimated pitch series is evaluated against a @@ -61,7 +61,7 @@ the proportion of all frames correctly estimated by the algorithm, including whether non-melody frames where labeled by the algorithm as non-melody -''' +""" import numpy as np import scipy.interpolate @@ -71,7 +71,7 @@ def validate_voicing(ref_voicing, est_voicing): - """Checks that voicing inputs to a metric are in the correct format. + """Check that voicing inputs to a metric are in the correct format. Parameters ---------- @@ -91,16 +91,17 @@ def validate_voicing(ref_voicing, est_voicing): warnings.warn("Estimated melody has no voiced frames.") # Make sure they're the same length if ref_voicing.shape[0] != est_voicing.shape[0]: - raise ValueError('Reference and estimated voicing arrays should ' - 'be the same length.') + raise ValueError( + "Reference and estimated voicing arrays should " "be the same length." + ) for voicing in [ref_voicing, est_voicing]: # Make sure voicing is between 0 and 1 if np.logical_or(voicing < 0, voicing > 1).any(): - raise ValueError('Voicing arrays must be between 0 and 1.') + raise ValueError("Voicing arrays must be between 0 and 1.") def validate(ref_voicing, ref_cent, est_voicing, est_cent): - """Checks that voicing and frequency arrays are well-formed. To be used in + """Check that voicing and frequency arrays are well-formed. To be used in conjunction with :func:`mir_eval.melody.validate_voicing` Parameters @@ -120,11 +121,14 @@ def validate(ref_voicing, ref_cent, est_voicing, est_cent): if est_cent.size == 0: warnings.warn("Estimated frequency array is empty.") # Make sure they're the same length - if ref_voicing.shape[0] != ref_cent.shape[0] or \ - est_voicing.shape[0] != est_cent.shape[0] or \ - ref_cent.shape[0] != est_cent.shape[0]: - raise ValueError('All voicing and frequency arrays must have the ' - 'same length.') + if ( + ref_voicing.shape[0] != ref_cent.shape[0] + or est_voicing.shape[0] != est_cent.shape[0] + or ref_cent.shape[0] != est_cent.shape[0] + ): + raise ValueError( + "All voicing and frequency arrays must have the " "same length." + ) def hz2cents(freq_hz, base_frequency=10.0): @@ -138,10 +142,6 @@ def hz2cents(freq_hz, base_frequency=10.0): base_frequency : float Base frequency for conversion. (Default value = 10.0) - Returns - ------- - freq_cent : np.ndarray - Array of frequencies in cents, relative to base_frequency """ freq_cent = np.zeros(freq_hz.shape[0]) freq_nonz_ind = np.flatnonzero(freq_hz) @@ -185,8 +185,7 @@ def freq_to_voicing(frequencies, voicing=None): def constant_hop_timebase(hop, end_time): - """Generates a time series from 0 to ``end_time`` with times spaced ``hop`` - apart + """Generate a time series from 0 to ``end_time`` with times spaced ``hop`` apart Parameters ---------- @@ -203,14 +202,14 @@ def constant_hop_timebase(hop, end_time): """ # Compute new timebase. Rounding/linspace is to avoid float problems. end_time = np.round(end_time, 10) - times = np.linspace(0, hop * int(np.floor(end_time / hop)), - int(np.floor(end_time / hop)) + 1) + times = np.linspace( + 0, hop * int(np.floor(end_time / hop)), int(np.floor(end_time / hop)) + 1 + ) times = np.round(times, 10) return times -def resample_melody_series(times, frequencies, voicing, - times_new, kind='linear'): +def resample_melody_series(times, frequencies, voicing, times_new, kind="linear"): """Resamples frequency and voicing time series to a new timescale. Maintains any zero ("unvoiced") values in frequencies. @@ -247,14 +246,19 @@ def resample_melody_series(times, frequencies, voicing, # Warn when the delta between the original times is not constant, # unless times[0] == 0. and frequencies[0] == frequencies[1] (see logic at # the beginning of to_cent_voicing) - if not (np.allclose(np.diff(times), np.diff(times).mean()) or - (np.allclose(np.diff(times[1:]), np.diff(times[1:]).mean()) and - frequencies[0] == frequencies[1])): + if not ( + np.allclose(np.diff(times), np.diff(times).mean()) + or ( + np.allclose(np.diff(times[1:]), np.diff(times[1:]).mean()) + and frequencies[0] == frequencies[1] + ) + ): warnings.warn( "Non-uniform timescale passed to resample_melody_series. Pitch " "will be linearly interpolated, which will result in undesirable " "behavior if silences are indicated by missing values. Silences " - "should be indicated by nonpositive frequency values.") + "should be indicated by nonpositive frequency values." + ) # Round to avoid floating point problems times = np.round(times, 10) times_new = np.round(times_new, 10) @@ -264,7 +268,7 @@ def resample_melody_series(times, frequencies, voicing, frequencies = np.append(frequencies, 0) voicing = np.append(voicing, 0) # We need to fix zero transitions if interpolation is not zero or nearest - if kind != 'zero' and kind != 'nearest': + if kind != "zero" and kind != "nearest": # Fill in zero values with the last reported frequency # to avoid erroneous values when resampling frequencies_held = np.array(frequencies) @@ -272,40 +276,47 @@ def resample_melody_series(times, frequencies, voicing, if frequency == 0: frequencies_held[n + 1] = frequencies_held[n] # Linearly interpolate frequencies - frequencies_resampled = scipy.interpolate.interp1d(times, - frequencies_held, - kind)(times_new) + frequencies_resampled = scipy.interpolate.interp1d( + times, frequencies_held, kind + )(times_new) # Retain zeros - frequency_mask = scipy.interpolate.interp1d(times, - frequencies, - 'zero')(times_new) - frequencies_resampled *= (frequency_mask != 0) + frequency_mask = scipy.interpolate.interp1d(times, frequencies, "zero")( + times_new + ) + frequencies_resampled *= frequency_mask != 0 else: - frequencies_resampled = scipy.interpolate.interp1d(times, - frequencies, - kind)(times_new) + frequencies_resampled = scipy.interpolate.interp1d(times, frequencies, kind)( + times_new + ) # Use nearest-neighbor for voicing if it was used for frequencies # if voicing is not binary, use linear interpolation is_binary_voicing = np.all( - np.logical_or(np.equal(voicing, 0), np.equal(voicing, 1))) - if kind == 'nearest' or (kind == 'linear' and not is_binary_voicing): - voicing_resampled = scipy.interpolate.interp1d(times, - voicing, - kind)(times_new) + np.logical_or(np.equal(voicing, 0), np.equal(voicing, 1)) + ) + if kind == "nearest" or (kind == "linear" and not is_binary_voicing): + voicing_resampled = scipy.interpolate.interp1d(times, voicing, kind)(times_new) # otherwise, always use zeroth order else: - voicing_resampled = scipy.interpolate.interp1d(times, - voicing, - 'zero')(times_new) + voicing_resampled = scipy.interpolate.interp1d(times, voicing, "zero")( + times_new + ) return frequencies_resampled, voicing_resampled -def to_cent_voicing(ref_time, ref_freq, est_time, est_freq, - est_voicing=None, ref_reward=None, base_frequency=10., - hop=None, kind='linear'): - """Converts reference and estimated time/frequency (Hz) annotations to sampled +def to_cent_voicing( + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing=None, + ref_reward=None, + base_frequency=10.0, + hop=None, + kind="linear", +): + """Convert reference and estimated time/frequency (Hz) annotations to sampled frequency (cent)/voicing arrays. A zero frequency indicates "unvoiced". @@ -379,23 +390,32 @@ def to_cent_voicing(ref_time, ref_freq, est_time, est_freq, if hop is not None: # Resample to common time base ref_cent, ref_voicing = resample_melody_series( - ref_time, ref_cent, ref_voicing, - constant_hop_timebase(hop, ref_time.max()), kind) + ref_time, + ref_cent, + ref_voicing, + constant_hop_timebase(hop, ref_time.max()), + kind, + ) est_cent, est_voicing = resample_melody_series( - est_time, est_cent, est_voicing, - constant_hop_timebase(hop, est_time.max()), kind) + est_time, + est_cent, + est_voicing, + constant_hop_timebase(hop, est_time.max()), + kind, + ) # Otherwise, only resample estimated to the reference time base else: est_cent, est_voicing = resample_melody_series( - est_time, est_cent, est_voicing, ref_time, kind) + est_time, est_cent, est_voicing, ref_time, kind + ) # ensure the estimated sequence is the same length as the reference len_diff = ref_cent.shape[0] - est_cent.shape[0] if len_diff >= 0: est_cent = np.append(est_cent, np.zeros(len_diff)) est_voicing = np.append(est_voicing, np.zeros(len_diff)) else: - est_cent = est_cent[:ref_cent.shape[0]] - est_voicing = est_voicing[:ref_voicing.shape[0]] + est_cent = est_cent[: ref_cent.shape[0]] + est_voicing = est_voicing[: ref_voicing.shape[0]] return (ref_voicing, ref_cent, est_voicing, est_cent) @@ -404,6 +424,7 @@ def voicing_recall(ref_voicing, est_voicing): """Compute the voicing recall given two voicing indicator sequences, one as reference (truth) and the other as the estimate (prediction). The sequences must be of the same length. + Examples -------- >>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt') @@ -414,12 +435,14 @@ def voicing_recall(ref_voicing, est_voicing): ... est_time, ... est_freq) >>> recall = mir_eval.melody.voicing_recall(ref_v, est_v) + Parameters ---------- ref_voicing : np.ndarray Reference boolean voicing array est_voicing : np.ndarray Estimated boolean voicing array + Returns ------- vx_recall : float @@ -427,7 +450,7 @@ def voicing_recall(ref_voicing, est_voicing): indicated as voiced in est """ if ref_voicing.size == 0 or est_voicing.size == 0: - return 0. + return 0.0 ref_indicator = (ref_voicing > 0).astype(float) if np.sum(ref_indicator) == 0: return 1 @@ -438,6 +461,7 @@ def voicing_false_alarm(ref_voicing, est_voicing): """Compute the voicing false alarm rates given two voicing indicator sequences, one as reference (truth) and the other as the estimate (prediction). The sequences must be of the same length. + Examples -------- >>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt') @@ -448,12 +472,14 @@ def voicing_false_alarm(ref_voicing, est_voicing): ... est_time, ... est_freq) >>> false_alarm = mir_eval.melody.voicing_false_alarm(ref_v, est_v) + Parameters ---------- ref_voicing : np.ndarray Reference boolean voicing array est_voicing : np.ndarray Estimated boolean voicing array + Returns ------- vx_false_alarm : float @@ -461,7 +487,7 @@ def voicing_false_alarm(ref_voicing, est_voicing): indicated as voiced in est """ if ref_voicing.size == 0 or est_voicing.size == 0: - return 0. + return 0.0 ref_indicator = (ref_voicing == 0).astype(float) if np.sum(ref_indicator) == 0: return 0 @@ -472,6 +498,7 @@ def voicing_measures(ref_voicing, est_voicing): """Compute the voicing recall and false alarm rates given two voicing indicator sequences, one as reference (truth) and the other as the estimate (prediction). The sequences must be of the same length. + Examples -------- >>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt') @@ -483,12 +510,14 @@ def voicing_measures(ref_voicing, est_voicing): ... est_freq) >>> recall, false_alarm = mir_eval.melody.voicing_measures(ref_v, ... est_v) + Parameters ---------- ref_voicing : np.ndarray Reference boolean voicing array est_voicing : np.ndarray Estimated boolean voicing array + Returns ------- vx_recall : float @@ -504,8 +533,7 @@ def voicing_measures(ref_voicing, est_voicing): return vx_recall, vx_false_alm -def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, - cent_tolerance=50): +def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50): """Compute the raw pitch accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and voicing arrays are treated as the reference (truth), and the second two as the @@ -545,16 +573,18 @@ def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, Raw pitch accuracy, the fraction of voiced frames in ref_cent for which est_cent provides a correct frequency values (within cent_tolerance cents). - """ - validate_voicing(ref_voicing, est_voicing) validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case # If there are no voiced frames in reference, metric is 0 - if ref_voicing.size == 0 or ref_voicing.sum() == 0 \ - or ref_cent.size == 0 or est_cent.size == 0: - return 0. + if ( + ref_voicing.size == 0 + or ref_voicing.sum() == 0 + or ref_cent.size == 0 + or est_cent.size == 0 + ): + return 0.0 # Raw pitch = the number of voiced frames in the reference for which the # estimate provides a correct frequency value (within cent_tolerance cents) @@ -563,19 +593,17 @@ def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0) if sum(nonzero_freqs) == 0: - return 0. + return 0.0 freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs] correct_frequencies = freq_diff_cents < cent_tolerance - rpa = ( - np.sum(ref_voicing[nonzero_freqs] * correct_frequencies) / - np.sum(ref_voicing) - ) + rpa = np.sum(ref_voicing[nonzero_freqs] * correct_frequencies) / np.sum(ref_voicing) return rpa -def raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, - cent_tolerance=50): +def raw_chroma_accuracy( + ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50 +): """Compute the raw chroma accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and voicing arrays are treated as the reference (truth), and the second two as @@ -593,7 +621,6 @@ def raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, >>> raw_chroma = mir_eval.melody.raw_chroma_accuracy(ref_v, ref_c, ... est_v, est_c) - Parameters ---------- ref_voicing : np.ndarray @@ -622,28 +649,28 @@ def raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case # If there are no voiced frames in reference, metric is 0 - if ref_voicing.size == 0 or ref_voicing.sum() == 0 \ - or ref_cent.size == 0 or est_cent.size == 0: - return 0. + if ( + ref_voicing.size == 0 + or ref_voicing.sum() == 0 + or ref_cent.size == 0 + or est_cent.size == 0 + ): + return 0.0 # # Raw chroma = same as raw pitch except that octave errors are ignored. nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0) if sum(nonzero_freqs) == 0: - return 0. + return 0.0 freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs] octave = 1200.0 * np.floor(freq_diff_cents / 1200 + 0.5) correct_chroma = np.abs(freq_diff_cents - octave) < cent_tolerance - rca = ( - np.sum(ref_voicing[nonzero_freqs] * correct_chroma) / - np.sum(ref_voicing) - ) + rca = np.sum(ref_voicing[nonzero_freqs] * correct_chroma) / np.sum(ref_voicing) return rca -def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, - cent_tolerance=50): +def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50): """Compute the overall accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and voicing arrays are treated as the reference (truth), and the second two @@ -688,9 +715,13 @@ def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case - if ref_voicing.size == 0 or est_voicing.size == 0 \ - or ref_cent.size == 0 or est_cent.size == 0: - return 0. + if ( + ref_voicing.size == 0 + or est_voicing.size == 0 + or ref_cent.size == 0 + or est_cent.size == 0 + ): + return 0.0 nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0) freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs] @@ -701,22 +732,26 @@ def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, if np.sum(ref_voicing) == 0: ratio = 0.0 else: - ratio = (np.sum(ref_binary) / np.sum(ref_voicing)) + ratio = np.sum(ref_binary) / np.sum(ref_voicing) accuracy = ( ( - ratio * np.sum(ref_voicing[nonzero_freqs] * - est_voicing[nonzero_freqs] * - correct_frequencies) - ) + - np.sum((1.0 - ref_binary) * (1.0 - est_voicing)) + ratio + * np.sum( + ref_voicing[nonzero_freqs] + * est_voicing[nonzero_freqs] + * correct_frequencies + ) + ) + + np.sum((1.0 - ref_binary) * (1.0 - est_voicing)) ) / n_frames return accuracy -def evaluate(ref_time, ref_freq, est_time, est_freq, - est_voicing=None, ref_reward=None, **kwargs): +def evaluate( + ref_time, ref_freq, est_time, est_freq, est_voicing=None, ref_reward=None, **kwargs +): """Evaluate two melody (predominant f0) transcriptions, where the first is treated as the reference (ground truth) and the second as the estimate to be evaluated (prediction). @@ -746,7 +781,7 @@ def evaluate(ref_time, ref_freq, est_time, est_freq, ref_reward : np.ndarray Reference pitch estimation reward. Default None, which means all frames are weighted equally. - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -756,7 +791,6 @@ def evaluate(ref_time, ref_freq, est_time, est_freq, Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. - References ---------- .. [#] J. Salamon, E. Gomez, D. P. W. Ellis and G. Richard, "Melody @@ -764,7 +798,6 @@ def evaluate(ref_time, ref_freq, est_time, est_freq, and Challenges", IEEE Signal Processing Magazine, 31(2):118-134, Mar. 2014. - .. [#] G. E. Poliner, D. P. W. Ellis, A. F. Ehmann, E. Gomez, S. Streich, and B. Ong. "Melody transcription from music audio: Approaches and evaluation", IEEE Transactions on Audio, Speech, and @@ -776,34 +809,37 @@ def evaluate(ref_time, ref_freq, est_time, est_freq, """ # Convert to reference/estimated voicing/frequency (cent) arrays - (ref_voicing, ref_cent, - est_voicing, est_cent) = util.filter_kwargs( - to_cent_voicing, ref_time, ref_freq, est_time, est_freq, - est_voicing, ref_reward, **kwargs) + (ref_voicing, ref_cent, est_voicing, est_cent) = util.filter_kwargs( + to_cent_voicing, + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing, + ref_reward, + **kwargs + ) # Compute metrics scores = collections.OrderedDict() - scores['Voicing Recall'] = util.filter_kwargs(voicing_recall, - ref_voicing, - est_voicing, **kwargs) + scores["Voicing Recall"] = util.filter_kwargs( + voicing_recall, ref_voicing, est_voicing, **kwargs + ) - scores['Voicing False Alarm'] = util.filter_kwargs(voicing_false_alarm, - ref_voicing, - est_voicing, **kwargs) + scores["Voicing False Alarm"] = util.filter_kwargs( + voicing_false_alarm, ref_voicing, est_voicing, **kwargs + ) - scores['Raw Pitch Accuracy'] = util.filter_kwargs(raw_pitch_accuracy, - ref_voicing, ref_cent, - est_voicing, est_cent, - **kwargs) + scores["Raw Pitch Accuracy"] = util.filter_kwargs( + raw_pitch_accuracy, ref_voicing, ref_cent, est_voicing, est_cent, **kwargs + ) - scores['Raw Chroma Accuracy'] = util.filter_kwargs(raw_chroma_accuracy, - ref_voicing, ref_cent, - est_voicing, est_cent, - **kwargs) + scores["Raw Chroma Accuracy"] = util.filter_kwargs( + raw_chroma_accuracy, ref_voicing, ref_cent, est_voicing, est_cent, **kwargs + ) - scores['Overall Accuracy'] = util.filter_kwargs(overall_accuracy, - ref_voicing, ref_cent, - est_voicing, est_cent, - **kwargs) + scores["Overall Accuracy"] = util.filter_kwargs( + overall_accuracy, ref_voicing, ref_cent, est_voicing, est_cent, **kwargs + ) return scores diff --git a/mir_eval/multipitch.py b/mir_eval/multipitch.py index 2d0f5150..bc74f3de 100644 --- a/mir_eval/multipitch.py +++ b/mir_eval/multipitch.py @@ -1,4 +1,4 @@ -''' +""" The goal of multiple f0 (multipitch) estimation and tracking is to identify all of the active fundamental frequencies in each time frame in a complex music signal. @@ -40,7 +40,7 @@ Signal Processing, 2007(1):154-163, Jan. 2007. .. [#bay2009] Bay, M., Ehmann, A. F., & Downie, J. S. (2009). Evaluation of Multiple-F0 Estimation and Tracking Systems. In ISMIR (pp. 315-320). -''' +""" import numpy as np import collections @@ -49,13 +49,13 @@ import warnings -MAX_TIME = 30000. # The maximum allowable time stamp (seconds) -MAX_FREQ = 5000. # The maximum allowable frequency (Hz) -MIN_FREQ = 20. # The minimum allowable frequency (Hz) +MAX_TIME = 30000.0 # The maximum allowable time stamp (seconds) +MAX_FREQ = 5000.0 # The maximum allowable frequency (Hz) +MIN_FREQ = 20.0 # The minimum allowable frequency (Hz) def validate(ref_time, ref_freqs, est_time, est_freqs): - """Checks that the time and frequency inputs are well-formed. + """Check that the time and frequency inputs are well-formed. Parameters ---------- @@ -67,9 +67,7 @@ def validate(ref_time, ref_freqs, est_time, est_freqs): estimate time stamps in seconds est_freqs : list of np.ndarray estimated frequencies in Hz - """ - util.validate_events(ref_time, max_time=MAX_TIME) util.validate_events(est_time, max_time=MAX_TIME) @@ -86,19 +84,19 @@ def validate(ref_time, ref_freqs, est_time, est_freqs): if len(est_freqs) == 0: warnings.warn("Estimated frequencies are empty.") if ref_time.size != len(ref_freqs): - raise ValueError('Reference times and frequencies have unequal ' - 'lengths.') + raise ValueError("Reference times and frequencies have unequal " "lengths.") if est_time.size != len(est_freqs): - raise ValueError('Estimate times and frequencies have unequal ' - 'lengths.') + raise ValueError("Estimate times and frequencies have unequal " "lengths.") for freq in ref_freqs: - util.validate_frequencies(freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ, - allow_negatives=False) + util.validate_frequencies( + freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ, allow_negatives=False + ) for freq in est_freqs: - util.validate_frequencies(freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ, - allow_negatives=False) + util.validate_frequencies( + freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ, allow_negatives=False + ) def resample_multipitch(times, frequencies, target_times): @@ -124,7 +122,7 @@ def resample_multipitch(times, frequencies, target_times): return [] if times.size == 0: - return [np.array([])]*len(target_times) + return [np.array([])] * len(target_times) n_times = len(frequencies) @@ -137,22 +135,26 @@ def resample_multipitch(times, frequencies, target_times): # since we're interpolating the index, fill_value is set to the first index # that is out of range. We handle this in the next line. new_frequency_index = scipy.interpolate.interp1d( - times, frequency_index, kind='nearest', bounds_error=False, - assume_sorted=True, fill_value=n_times)(target_times) + times, + frequency_index, + kind="nearest", + bounds_error=False, + assume_sorted=True, + fill_value=n_times, + )(target_times) # create array of frequencies plus additional empty element at the end for # target time stamps that are out of the interpolation range freq_vals = frequencies + [np.array([])] # map interpolated indices back to frequency values - frequencies_resampled = [ - freq_vals[i] for i in new_frequency_index.astype(int)] + frequencies_resampled = [freq_vals[i] for i in new_frequency_index.astype(int)] return frequencies_resampled def frequencies_to_midi(frequencies, ref_frequency=440.0): - """Converts frequencies to continuous MIDI values. + """Convert frequencies to continuous MIDI values. Parameters ---------- @@ -166,7 +168,7 @@ def frequencies_to_midi(frequencies, ref_frequency=440.0): frequencies_midi : list of np.ndarray Continuous MIDI frequency values. """ - return [69.0 + 12.0*np.log2(freqs/ref_frequency) for freqs in frequencies] + return [69.0 + 12.0 * np.log2(freqs / ref_frequency) for freqs in frequencies] def midi_to_chroma(frequencies_midi): @@ -187,7 +189,7 @@ def midi_to_chroma(frequencies_midi): def compute_num_freqs(frequencies): - """Computes the number of frequencies for each time point. + """Compute the number of frequencies for each time point. Parameters ---------- @@ -227,14 +229,14 @@ def compute_num_true_positives(ref_freqs, est_freqs, window=0.5, chroma=False): """ n_frames = len(ref_freqs) - true_positives = np.zeros((n_frames, )) + true_positives = np.zeros((n_frames,)) for i, (ref_frame, est_frame) in enumerate(zip(ref_freqs, est_freqs)): if chroma: # match chroma-wrapped frequency events matching = util.match_events( - ref_frame, est_frame, window, - distance=util._outer_distance_mod_n) + ref_frame, est_frame, window, distance=util._outer_distance_mod_n + ) else: # match frequency events within tolerance window in semitones matching = util.match_events(ref_frame, est_frame, window) @@ -271,21 +273,21 @@ def compute_accuracy(true_positives, n_ref, n_est): n_est_sum = n_est.sum() if n_est_sum > 0: - precision = true_positive_sum/n_est.sum() + precision = true_positive_sum / n_est.sum() else: warnings.warn("Estimate frequencies are all empty.") precision = 0.0 n_ref_sum = n_ref.sum() if n_ref_sum > 0: - recall = true_positive_sum/n_ref.sum() + recall = true_positive_sum / n_ref.sum() else: warnings.warn("Reference frequencies are all empty.") recall = 0.0 acc_denom = (n_est + n_ref - true_positives).sum() if acc_denom > 0: - acc = true_positive_sum/acc_denom + acc = true_positive_sum / acc_denom else: acc = 0.0 @@ -321,25 +323,25 @@ def compute_err_score(true_positives, n_ref, n_est): if n_ref_sum == 0: warnings.warn("Reference frequencies are all empty.") - return 0., 0., 0., 0. + return 0.0, 0.0, 0.0, 0.0 # Substitution error - e_sub = (np.min([n_ref, n_est], axis=0) - true_positives).sum()/n_ref_sum + e_sub = (np.min([n_ref, n_est], axis=0) - true_positives).sum() / n_ref_sum # compute the max of (n_ref - n_est) and 0 e_miss_numerator = n_ref - n_est e_miss_numerator[e_miss_numerator < 0] = 0 # Miss error - e_miss = e_miss_numerator.sum()/n_ref_sum + e_miss = e_miss_numerator.sum() / n_ref_sum # compute the max of (n_est - n_ref) and 0 e_fa_numerator = n_est - n_ref e_fa_numerator[e_fa_numerator < 0] = 0 # False alarm error - e_fa = e_fa_numerator.sum()/n_ref_sum + e_fa = e_fa_numerator.sum() / n_ref_sum # total error - e_tot = (np.max([n_ref, n_est], axis=0) - true_positives).sum()/n_ref_sum + e_tot = (np.max([n_ref, n_est], axis=0) - true_positives).sum() / n_ref_sum return e_sub, e_miss, e_fa, e_tot @@ -368,7 +370,7 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): Time of each estimated frequency value est_freqs : list of np.ndarray List of np.ndarrays of estimate frequency values - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -408,8 +410,10 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): # resample est_freqs if est_times is different from ref_times if est_time.size != ref_time.size or not np.allclose(est_time, ref_time): - warnings.warn("Estimate times not equal to reference times. " - "Resampling to common time base.") + warnings.warn( + "Estimate times not equal to reference times. " + "Resampling to common time base." + ) est_freqs = resample_multipitch(est_time, est_freqs, ref_time) # convert frequencies from Hz to continuous midi note number @@ -420,38 +424,56 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): ref_freqs_chroma = midi_to_chroma(ref_freqs_midi) est_freqs_chroma = midi_to_chroma(est_freqs_midi) - # count number of occurences + # count number of occurrences n_ref = compute_num_freqs(ref_freqs_midi) n_est = compute_num_freqs(est_freqs_midi) # compute the number of true positives true_positives = util.filter_kwargs( - compute_num_true_positives, ref_freqs_midi, est_freqs_midi, **kwargs) + compute_num_true_positives, ref_freqs_midi, est_freqs_midi, **kwargs + ) # compute the number of true positives ignoring octave mistakes true_positives_chroma = util.filter_kwargs( - compute_num_true_positives, ref_freqs_chroma, - est_freqs_chroma, chroma=True, **kwargs) + compute_num_true_positives, + ref_freqs_chroma, + est_freqs_chroma, + chroma=True, + **kwargs + ) # compute accuracy metrics - precision, recall, accuracy = compute_accuracy( - true_positives, n_ref, n_est) + precision, recall, accuracy = compute_accuracy(true_positives, n_ref, n_est) # compute error metrics - e_sub, e_miss, e_fa, e_tot = compute_err_score( - true_positives, n_ref, n_est) + e_sub, e_miss, e_fa, e_tot = compute_err_score(true_positives, n_ref, n_est) # compute accuracy metrics ignoring octave mistakes precision_chroma, recall_chroma, accuracy_chroma = compute_accuracy( - true_positives_chroma, n_ref, n_est) + true_positives_chroma, n_ref, n_est + ) # compute error metrics ignoring octave mistakes e_sub_chroma, e_miss_chroma, e_fa_chroma, e_tot_chroma = compute_err_score( - true_positives_chroma, n_ref, n_est) - - return (precision, recall, accuracy, e_sub, e_miss, e_fa, e_tot, - precision_chroma, recall_chroma, accuracy_chroma, e_sub_chroma, - e_miss_chroma, e_fa_chroma, e_tot_chroma) + true_positives_chroma, n_ref, n_est + ) + + return ( + precision, + recall, + accuracy, + e_sub, + e_miss, + e_fa, + e_tot, + precision_chroma, + recall_chroma, + accuracy_chroma, + e_sub_chroma, + e_miss_chroma, + e_fa_chroma, + e_tot_chroma, + ) def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): @@ -476,7 +498,7 @@ def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): Time of each estimated frequency value est_freqs : list of np.ndarray List of np.ndarrays of estimate frequency values - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -489,20 +511,21 @@ def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): """ scores = collections.OrderedDict() - (scores['Precision'], - scores['Recall'], - scores['Accuracy'], - scores['Substitution Error'], - scores['Miss Error'], - scores['False Alarm Error'], - scores['Total Error'], - scores['Chroma Precision'], - scores['Chroma Recall'], - scores['Chroma Accuracy'], - scores['Chroma Substitution Error'], - scores['Chroma Miss Error'], - scores['Chroma False Alarm Error'], - scores['Chroma Total Error']) = util.filter_kwargs( - metrics, ref_time, ref_freqs, est_time, est_freqs, **kwargs) + ( + scores["Precision"], + scores["Recall"], + scores["Accuracy"], + scores["Substitution Error"], + scores["Miss Error"], + scores["False Alarm Error"], + scores["Total Error"], + scores["Chroma Precision"], + scores["Chroma Recall"], + scores["Chroma Accuracy"], + scores["Chroma Substitution Error"], + scores["Chroma Miss Error"], + scores["Chroma False Alarm Error"], + scores["Chroma Total Error"], + ) = util.filter_kwargs(metrics, ref_time, ref_freqs, est_time, est_freqs, **kwargs) return scores diff --git a/mir_eval/onset.py b/mir_eval/onset.py index 606ad0d7..d3437a33 100644 --- a/mir_eval/onset.py +++ b/mir_eval/onset.py @@ -1,4 +1,4 @@ -''' +""" The goal of an onset detection algorithm is to automatically determine when notes are played in a piece of music. The primary method used to evaluate onset detectors is to first determine which estimated onsets are "correct", @@ -19,9 +19,9 @@ ------- * :func:`mir_eval.onset.f_measure`: Precision, Recall, and F-measure scores - based on the number of esimated onsets which are sufficiently close to + based on the number of estimated onsets which are sufficiently close to reference onsets. -''' +""" import collections from . import util @@ -29,11 +29,11 @@ # The maximum allowable beat time -MAX_TIME = 30000. +MAX_TIME = 30000.0 def validate(reference_onsets, estimated_onsets): - """Checks that the input annotations to a metric look like valid onset time + """Check that the input annotations to a metric look like valid onset time arrays, and throws helpful errors if not. Parameters @@ -53,9 +53,9 @@ def validate(reference_onsets, estimated_onsets): util.validate_events(onsets, MAX_TIME) -def f_measure(reference_onsets, estimated_onsets, window=.05): +def f_measure(reference_onsets, estimated_onsets, window=0.05): """Compute the F-measure of correct vs incorrectly predicted onsets. - "Corectness" is determined over a small window. + "Correctness" is determined over a small window. Examples -------- @@ -87,13 +87,13 @@ def f_measure(reference_onsets, estimated_onsets, window=.05): validate(reference_onsets, estimated_onsets) # If either list is empty, return 0s if reference_onsets.size == 0 or estimated_onsets.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Compute the best-case matching between reference and estimated onset # locations matching = util.match_events(reference_onsets, estimated_onsets, window) - precision = float(len(matching))/len(estimated_onsets) - recall = float(len(matching))/len(reference_onsets) + precision = float(len(matching)) / len(estimated_onsets) + recall = float(len(matching)) / len(reference_onsets) # Compute F-measure and return all statistics return util.f_measure(precision, recall), precision, recall @@ -114,7 +114,7 @@ def evaluate(reference_onsets, estimated_onsets, **kwargs): reference onset locations, in seconds estimated_onsets : np.ndarray estimated onset locations, in seconds - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -128,9 +128,8 @@ def evaluate(reference_onsets, estimated_onsets, **kwargs): # Compute all metrics scores = collections.OrderedDict() - (scores['F-measure'], - scores['Precision'], - scores['Recall']) = util.filter_kwargs(f_measure, reference_onsets, - estimated_onsets, **kwargs) + (scores["F-measure"], scores["Precision"], scores["Recall"]) = util.filter_kwargs( + f_measure, reference_onsets, estimated_onsets, **kwargs + ) return scores diff --git a/mir_eval/pattern.py b/mir_eval/pattern.py index b5c016d5..1dc75436 100644 --- a/mir_eval/pattern.py +++ b/mir_eval/pattern.py @@ -16,7 +16,7 @@ The input format can be automatically generated by calling :func:`mir_eval.io.load_patterns`. This format is a list of a list of tuples. The first list collections patterns, each of which is a list of -occurences, and each occurrence is a list of MIDI onset tuples of +occurrences, and each occurrence is a list of MIDI onset tuples of ``(onset_time, mid_note)`` A pattern is a list of occurrences. The first occurrence must be the prototype @@ -54,7 +54,6 @@ relevance. """ - import numpy as np from . import util import warnings @@ -62,11 +61,11 @@ def _n_onset_midi(patterns): - """Computes the number of onset_midi objects in a pattern + """Compute the number of onset_midi objects in a pattern Parameters ---------- - patterns : + patterns A list of patterns using the format returned by :func:`mir_eval.io.load_patterns()` @@ -80,7 +79,7 @@ def _n_onset_midi(patterns): def validate(reference_patterns, estimated_patterns): - """Checks that the input annotations to a metric look like valid pattern + """Check that the input annotations to a metric look like valid pattern lists, and throws helpful errors if not. Parameters @@ -90,30 +89,29 @@ def validate(reference_patterns, estimated_patterns): :func:`mir_eval.io.load_patterns()` estimated_patterns : list The estimated patterns in the same format - - Returns - ------- - """ # Warn if pattern lists are empty if _n_onset_midi(reference_patterns) == 0: - warnings.warn('Reference patterns are empty.') + warnings.warn("Reference patterns are empty.") if _n_onset_midi(estimated_patterns) == 0: - warnings.warn('Estimated patterns are empty.') + warnings.warn("Estimated patterns are empty.") for patterns in [reference_patterns, estimated_patterns]: for pattern in patterns: if len(pattern) <= 0: - raise ValueError("Each pattern must contain at least one " - "occurrence.") + raise ValueError( + "Each pattern must contain at least one " "occurrence." + ) for occurrence in pattern: for onset_midi in occurrence: if len(onset_midi) != 2: - raise ValueError("The (onset, midi) tuple must " - "contain exactly 2 elements.") + raise ValueError( + "The (onset, midi) tuple must " + "contain exactly 2 elements." + ) def _occurrence_intersection(occ_P, occ_Q): - """Computes the intersection between two occurrences. + """Compute the intersection between two occurrences. Parameters ---------- @@ -130,11 +128,11 @@ def _occurrence_intersection(occ_P, occ_Q): """ set_P = set([tuple(onset_midi) for onset_midi in occ_P]) set_Q = set([tuple(onset_midi) for onset_midi in occ_Q]) - return set_P & set_Q # Return the intersection + return set_P & set_Q # Return the intersection def _compute_score_matrix(P, Q, similarity_metric="cardinality_score"): - """Computes the score matrix between the patterns P and Q. + """Compute the score matrix between the patterns P and Q. Parameters ---------- @@ -155,28 +153,28 @@ def _compute_score_matrix(P, Q, similarity_metric="cardinality_score"): The score matrix between P and Q using the similarity_metric. """ - sm = np.zeros((len(P), len(Q))) # The score matrix + sm = np.zeros((len(P), len(Q))) # The score matrix for iP, occ_P in enumerate(P): for iQ, occ_Q in enumerate(Q): if similarity_metric == "cardinality_score": denom = float(np.max([len(occ_P), len(occ_Q)])) # Compute the score - sm[iP, iQ] = len(_occurrence_intersection(occ_P, occ_Q)) / \ - denom + sm[iP, iQ] = len(_occurrence_intersection(occ_P, occ_Q)) / denom # TODO: More scores: 'normalised matching socre' else: - raise ValueError("The similarity metric (%s) can only be: " - "'cardinality_score'.") + raise ValueError( + "The similarity metric (%s) can only be: " "'cardinality_score'." + ) return sm def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): - """Standard F1 Score, Precision and Recall. + """Compute the standard F1 Score, Precision and Recall. This metric checks if the prototype patterns of the reference match possible translated patterns in the prototype patterns of the estimations. Since the sizes of these prototypes must be equal, this metric is quite - restictive and it tends to be 0 in most of 2013 MIREX results. + restrictive and it tends to be 0 in most of 2013 MIREX results. Examples -------- @@ -208,18 +206,17 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): """ validate(reference_patterns, estimated_patterns) - nP = len(reference_patterns) # Number of patterns in the reference - nQ = len(estimated_patterns) # Number of patterns in the estimation - k = 0 # Number of patterns that match + nP = len(reference_patterns) # Number of patterns in the reference + nQ = len(estimated_patterns) # Number of patterns in the estimation + k = 0 # Number of patterns that match # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 # Find matches of the prototype patterns for ref_pattern in reference_patterns: - P = np.asarray(ref_pattern[0]) # Get reference prototype + P = np.asarray(ref_pattern[0]) # Get reference prototype for est_pattern in estimated_patterns: Q = np.asarray(est_pattern[0]) # Get estimation prototype @@ -227,8 +224,7 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): continue # Check transposition given a certain tolerance - if (len(P) == len(Q) == 1 or - np.max(np.abs(np.diff(P - Q, axis=0))) < tol): + if len(P) == len(Q) == 1 or np.max(np.abs(np.diff(P - Q, axis=0))) < tol: k += 1 break @@ -239,9 +235,10 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): return f_measure, precision, recall -def establishment_FPR(reference_patterns, estimated_patterns, - similarity_metric="cardinality_score"): - """Establishment F1 Score, Precision and Recall. +def establishment_FPR( + reference_patterns, estimated_patterns, similarity_metric="cardinality_score" +): + """Compute the establishment F1 Score, Precision and Recall. Examples -------- @@ -250,7 +247,6 @@ def establishment_FPR(reference_patterns, estimated_patterns, >>> F, P, R = mir_eval.pattern.establishment_FPR(ref_patterns, ... est_patterns) - Parameters ---------- reference_patterns : list @@ -269,7 +265,6 @@ def establishment_FPR(reference_patterns, estimated_patterns, (Default value = "cardinality_score") - Returns ------- f_measure : float @@ -281,19 +276,17 @@ def establishment_FPR(reference_patterns, estimated_patterns, """ validate(reference_patterns, estimated_patterns) - nP = len(reference_patterns) # Number of elements in reference - nQ = len(estimated_patterns) # Number of elements in estimation - S = np.zeros((nP, nQ)) # Establishment matrix + nP = len(reference_patterns) # Number of elements in reference + nQ = len(estimated_patterns) # Number of elements in estimation + S = np.zeros((nP, nQ)) # Establishment matrix # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 for iP, ref_pattern in enumerate(reference_patterns): for iQ, est_pattern in enumerate(estimated_patterns): - s = _compute_score_matrix(ref_pattern, est_pattern, - similarity_metric) + s = _compute_score_matrix(ref_pattern, est_pattern, similarity_metric) S[iP, iQ] = np.max(s) # Compute scores @@ -303,10 +296,13 @@ def establishment_FPR(reference_patterns, estimated_patterns, return f_measure, precision, recall -def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75, - similarity_metric="cardinality_score"): - """Establishment F1 Score, Precision and Recall. - +def occurrence_FPR( + reference_patterns, + estimated_patterns, + thres=0.75, + similarity_metric="cardinality_score", +): + """Compute the occurrence F1 Score, Precision and Recall. Examples -------- @@ -315,18 +311,20 @@ def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75, >>> F, P, R = mir_eval.pattern.occurrence_FPR(ref_patterns, ... est_patterns) - Parameters ---------- reference_patterns : list The reference patterns in the format returned by :func:`mir_eval.io.load_patterns()` + estimated_patterns : list The estimated patterns in the same format + thres : float - How similar two occcurrences must be in order to be considered + How similar two occurrences must be in order to be considered equal (Default value = .75) + similarity_metric : str A string representing the metric to be used when computing the similarity matrix. Accepted values: @@ -336,16 +334,14 @@ def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75, (Default value = "cardinality_score") - Returns ------- f_measure : float - The establishment F1 Score + The occurrence F1 Score precision : float - The establishment Precision + The occurrence Precision recall : float - The establishment Recall - + The occurrence Recall """ validate(reference_patterns, estimated_patterns) # Number of elements in reference @@ -359,14 +355,12 @@ def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75, rel_idx = np.empty((0, 2), dtype=int) # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 for iP, ref_pattern in enumerate(reference_patterns): for iQ, est_pattern in enumerate(estimated_patterns): - s = _compute_score_matrix(ref_pattern, est_pattern, - similarity_metric) + s = _compute_score_matrix(ref_pattern, est_pattern, similarity_metric) if np.max(s) >= thres: O_PR[iP, iQ, 0] = np.mean(np.max(s, axis=0)) O_PR[iP, iQ, 1] = np.mean(np.max(s, axis=1)) @@ -378,11 +372,9 @@ def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75, recall = 0 else: P = O_PR[:, :, 0] - precision = np.mean(np.max(P[np.ix_(rel_idx[:, 0], rel_idx[:, 1])], - axis=0)) + precision = np.mean(np.max(P[np.ix_(rel_idx[:, 0], rel_idx[:, 1])], axis=0)) R = O_PR[:, :, 1] - recall = np.mean(np.max(R[np.ix_(rel_idx[:, 0], rel_idx[:, 1])], - axis=1)) + recall = np.mean(np.max(R[np.ix_(rel_idx[:, 0], rel_idx[:, 1])], axis=1)) f_measure = util.f_measure(precision, recall) return f_measure, precision, recall @@ -418,20 +410,19 @@ def three_layer_FPR(reference_patterns, estimated_patterns): validate(reference_patterns, estimated_patterns) def compute_first_layer_PR(ref_occs, est_occs): - """Computes the first layer Precision and Recall values given the + """Compute the first layer Precision and Recall values given the set of occurrences in the reference and the set of occurrences in the estimation. Parameters ---------- - ref_occs : - - est_occs : - + ref_occs + est_occs Returns ------- - + precision + recall """ # Find the length of the intersection between reference and estimation s = len(_occurrence_intersection(ref_occs, est_occs)) @@ -442,20 +433,19 @@ def compute_first_layer_PR(ref_occs, est_occs): return precision, recall def compute_second_layer_PR(ref_pattern, est_pattern): - """Computes the second layer Precision and Recall values given the + """Compute the second layer Precision and Recall values given the set of occurrences in the reference and the set of occurrences in the estimation. Parameters ---------- - ref_pattern : - - est_pattern : - + ref_pattern + est_pattern Returns ------- - + precision + recall """ # Compute the first layer scores F_1 = compute_layer(ref_pattern, est_pattern) @@ -466,8 +456,8 @@ def compute_second_layer_PR(ref_pattern, est_pattern): return precision, recall def compute_layer(ref_elements, est_elements, layer=1): - """Computes the F-measure matrix for a given layer. The reference and - estimated elements can be either patters or occurrences, depending + """Compute the F-measure matrix for a given layer. The reference and + estimated elements can be either patterns or occurrences, depending on the layer. For layer 1, the elements must be occurrences. @@ -475,24 +465,21 @@ def compute_layer(ref_elements, est_elements, layer=1): Parameters ---------- - ref_elements : - - est_elements : - - layer : - (Default value = 1) + ref_elements + est_elements + layer + (Default value = 1) Returns ------- - + F : F-measure for the given layer """ if layer != 1 and layer != 2: - raise ValueError("Layer (%d) must be an integer between 1 and 2" - % layer) + raise ValueError("Layer (%d) must be an integer between 1 and 2" % layer) - nP = len(ref_elements) # Number of elements in reference - nQ = len(est_elements) # Number of elements in estimation - F = np.zeros((nP, nQ)) # F-measure matrix for the given layer + nP = len(ref_elements) # Number of elements in reference + nQ = len(est_elements) # Number of elements in estimation + F = np.zeros((nP, nQ)) # F-measure matrix for the given layer for iP in range(nP): for iQ in range(nQ): if layer == 1: @@ -506,9 +493,8 @@ def compute_layer(ref_elements, est_elements, layer=1): return F # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 # Compute the second layer (it includes the first layer) F_2 = compute_layer(reference_patterns, estimated_patterns, layer=2) @@ -550,22 +536,19 @@ def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5): ------- precision : float The first n three-layer Precision - """ - validate(reference_patterns, estimated_patterns) # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 # Get only the first n patterns from the estimated results - fn_est_patterns = estimated_patterns[:min(len(estimated_patterns), n)] + fn_est_patterns = estimated_patterns[: min(len(estimated_patterns), n)] # Compute the three-layer scores for the first n estimated patterns F, P, R = three_layer_FPR(reference_patterns, fn_est_patterns) - return P # Return the precision only + return P # Return the precision only def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5): @@ -598,17 +581,14 @@ def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5): ------- recall : float The first n target proportion Recall. - """ - validate(reference_patterns, estimated_patterns) # If no patterns were provided, metric is zero - if _n_onset_midi(reference_patterns) == 0 or \ - _n_onset_midi(estimated_patterns) == 0: - return 0., 0., 0. + if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: + return 0.0, 0.0, 0.0 # Get only the first n patterns from the estimated results - fn_est_patterns = estimated_patterns[:min(len(estimated_patterns), n)] + fn_est_patterns = estimated_patterns[: min(len(estimated_patterns), n)] F, P, R = establishment_FPR(reference_patterns, fn_est_patterns) return R @@ -630,7 +610,7 @@ def evaluate(ref_patterns, est_patterns, **kwargs): :func:`mir_eval.io.load_patterns()` est_patterns : list The estimated patterns in the same format - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -639,45 +619,45 @@ def evaluate(ref_patterns, est_patterns, **kwargs): scores : dict Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. - """ - # Compute all the metrics scores = collections.OrderedDict() # Standard scores - scores['F'], scores['P'], scores['R'] = \ - util.filter_kwargs(standard_FPR, ref_patterns, est_patterns, **kwargs) + scores["F"], scores["P"], scores["R"] = util.filter_kwargs( + standard_FPR, ref_patterns, est_patterns, **kwargs + ) # Establishment scores - scores['F_est'], scores['P_est'], scores['R_est'] = \ - util.filter_kwargs(establishment_FPR, ref_patterns, est_patterns, - **kwargs) + scores["F_est"], scores["P_est"], scores["R_est"] = util.filter_kwargs( + establishment_FPR, ref_patterns, est_patterns, **kwargs + ) # Occurrence scores # Force these values for thresh - kwargs['thresh'] = .5 - scores['F_occ.5'], scores['P_occ.5'], scores['R_occ.5'] = \ - util.filter_kwargs(occurrence_FPR, ref_patterns, est_patterns, - **kwargs) - kwargs['thresh'] = .75 - scores['F_occ.75'], scores['P_occ.75'], scores['R_occ.75'] = \ - util.filter_kwargs(occurrence_FPR, ref_patterns, est_patterns, - **kwargs) + kwargs["thresh"] = 0.5 + scores["F_occ.5"], scores["P_occ.5"], scores["R_occ.5"] = util.filter_kwargs( + occurrence_FPR, ref_patterns, est_patterns, **kwargs + ) + kwargs["thresh"] = 0.75 + scores["F_occ.75"], scores["P_occ.75"], scores["R_occ.75"] = util.filter_kwargs( + occurrence_FPR, ref_patterns, est_patterns, **kwargs + ) # Three-layer scores - scores['F_3'], scores['P_3'], scores['R_3'] = \ - util.filter_kwargs(three_layer_FPR, ref_patterns, est_patterns, - **kwargs) + scores["F_3"], scores["P_3"], scores["R_3"] = util.filter_kwargs( + three_layer_FPR, ref_patterns, est_patterns, **kwargs + ) # First Five Patterns scores # Set default value of n - if 'n' not in kwargs: - kwargs['n'] = 5 - scores['FFP'] = util.filter_kwargs(first_n_three_layer_P, ref_patterns, - est_patterns, **kwargs) - scores['FFTP_est'] = \ - util.filter_kwargs(first_n_target_proportion_R, ref_patterns, - est_patterns, **kwargs) + if "n" not in kwargs: + kwargs["n"] = 5 + scores["FFP"] = util.filter_kwargs( + first_n_three_layer_P, ref_patterns, est_patterns, **kwargs + ) + scores["FFTP_est"] = util.filter_kwargs( + first_n_target_proportion_R, ref_patterns, est_patterns, **kwargs + ) return scores diff --git a/mir_eval/segment.py b/mir_eval/segment.py index d22edb7d..7a49d6ff 100644 --- a/mir_eval/segment.py +++ b/mir_eval/segment.py @@ -1,5 +1,5 @@ # CREATED:2013-08-13 12:02:42 by Brian McFee -''' +""" Evaluation criteria for structural segmentation fall into two categories: boundary annotation and structural annotation. Boundary annotation is the task of predicting the times at which structural changes occur, such as when a verse @@ -70,7 +70,7 @@ V-Measure: A Conditional Entropy-Based External Cluster Evaluation Measure. In EMNLP-CoNLL (Vol. 7, pp. 410-420). -''' +""" import collections import warnings @@ -85,7 +85,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): - """Checks that the input annotations to a segment boundary estimation + """Check that the input annotations to a segment boundary estimation metric (i.e. one that only takes in segment intervals) look like valid segment times, and throws helpful errors if not. @@ -95,17 +95,13 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): reference segment intervals, in the format returned by :func:`mir_eval.io.load_intervals` or :func:`mir_eval.io.load_labeled_intervals`. - estimated_intervals : np.ndarray, shape=(m, 2) estimated segment intervals, in the format returned by :func:`mir_eval.io.load_intervals` or :func:`mir_eval.io.load_labeled_intervals`. - trim : bool will the start and end events be trimmed? - """ - if trim: # If we're trimming, then we need at least 2 intervals min_size = 2 @@ -123,9 +119,10 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): util.validate_intervals(intervals) -def validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels): - """Checks that the input annotations to a structure estimation metric (i.e. +def validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels +): + """Check that the input annotations to a structure estimation metric (i.e. one that takes in both segment boundaries and their labels) look like valid segment times and labels, and throws helpful errors if not. @@ -134,33 +131,30 @@ def validate_structure(reference_intervals, reference_labels, reference_intervals : np.ndarray, shape=(n, 2) reference segment intervals, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - reference_labels : list, shape=(n,) reference segment labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - estimated_intervals : np.ndarray, shape=(m, 2) estimated segment intervals, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - estimated_labels : list, shape=(m,) estimated segment labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. """ - for (intervals, labels) in [(reference_intervals, reference_labels), - (estimated_intervals, estimated_labels)]: - + for intervals, labels in [ + (reference_intervals, reference_labels), + (estimated_intervals, estimated_labels), + ]: util.validate_intervals(intervals) if intervals.shape[0] != len(labels): - raise ValueError('Number of intervals does not match number ' - 'of labels') + raise ValueError("Number of intervals does not match number " "of labels") # Check only when intervals are non-empty if intervals.size > 0: # Make sure intervals start at 0 if not np.allclose(intervals.min(), 0.0): - raise ValueError('Segment intervals do not start at 0') + raise ValueError("Segment intervals do not start at 0") if reference_intervals.size == 0: warnings.warn("Reference intervals are empty.") @@ -168,13 +162,13 @@ def validate_structure(reference_intervals, reference_labels, warnings.warn("Estimated intervals are empty.") # Check only when intervals are non-empty if reference_intervals.size > 0 and estimated_intervals.size > 0: - if not np.allclose(reference_intervals.max(), - estimated_intervals.max()): - raise ValueError('End times do not match') + if not np.allclose(reference_intervals.max(), estimated_intervals.max()): + raise ValueError("End times do not match") -def detection(reference_intervals, estimated_intervals, - window=0.5, beta=1.0, trim=False): +def detection( + reference_intervals, estimated_intervals, window=0.5, beta=1.0, trim=False +): """Boundary detection hit-rate. A hit is counted whenever an reference boundary is within ``window`` of a @@ -230,9 +224,7 @@ def detection(reference_intervals, estimated_intervals, recall of reference reference boundaries f_measure : float F-measure (weighted harmonic mean of ``precision`` and ``recall``) - """ - validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries @@ -248,9 +240,7 @@ def detection(reference_intervals, estimated_intervals, if len(reference_boundaries) == 0 or len(estimated_boundaries) == 0: return 0.0, 0.0, 0.0 - matching = util.match_events(reference_boundaries, - estimated_boundaries, - window) + matching = util.match_events(reference_boundaries, estimated_boundaries, window) precision = float(len(matching)) / len(estimated_boundaries) recall = float(len(matching)) / len(reference_boundaries) @@ -294,9 +284,7 @@ def deviation(reference_intervals, estimated_intervals, trim=False): estimated_to_reference : float median time from each estimated boundary to the closest reference boundary - """ - validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries @@ -312,8 +300,7 @@ def deviation(reference_intervals, estimated_intervals, trim=False): if len(reference_boundaries) == 0 or len(estimated_boundaries) == 0: return np.nan, np.nan - dist = np.abs(np.subtract.outer(reference_boundaries, - estimated_boundaries)) + dist = np.abs(np.subtract.outer(reference_boundaries, estimated_boundaries)) estimated_to_reference = np.median(dist.min(axis=0)) reference_to_estimated = np.median(dist.min(axis=1)) @@ -321,9 +308,14 @@ def deviation(reference_intervals, estimated_intervals, trim=False): return reference_to_estimated, estimated_to_reference -def pairwise(reference_intervals, reference_labels, - estimated_intervals, estimated_labels, - frame_size=0.1, beta=1.0): +def pairwise( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, + beta=1.0, +): """Frame-clustering segmentation evaluation by pair-wise agreement. Examples @@ -376,25 +368,26 @@ def pairwise(reference_intervals, reference_labels, F-measure of detecting whether frames belong in the same cluster """ - validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels) + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals if reference_intervals.size == 0 or estimated_intervals.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Generate the cluster labels - y_ref = util.intervals_to_samples(reference_intervals, - reference_labels, - sample_size=frame_size)[-1] + y_ref = util.intervals_to_samples( + reference_intervals, reference_labels, sample_size=frame_size + )[-1] y_ref = util.index_labels(y_ref)[0] # Map to index space - y_est = util.intervals_to_samples(estimated_intervals, - estimated_labels, - sample_size=frame_size)[-1] + y_est = util.intervals_to_samples( + estimated_intervals, estimated_labels, sample_size=frame_size + )[-1] y_est = util.index_labels(y_est)[0] @@ -418,9 +411,14 @@ def pairwise(reference_intervals, reference_labels, return precision, recall, f_measure -def rand_index(reference_intervals, reference_labels, - estimated_intervals, estimated_labels, - frame_size=0.1, beta=1.0): +def rand_index( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, + beta=1.0, +): """(Non-adjusted) Rand index. Examples @@ -467,28 +465,27 @@ def rand_index(reference_intervals, reference_labels, ------- rand_index : float > 0 Rand index - """ - - validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels) + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals if reference_intervals.size == 0 or estimated_intervals.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Generate the cluster labels - y_ref = util.intervals_to_samples(reference_intervals, - reference_labels, - sample_size=frame_size)[-1] + y_ref = util.intervals_to_samples( + reference_intervals, reference_labels, sample_size=frame_size + )[-1] y_ref = util.index_labels(y_ref)[0] # Map to index space - y_est = util.intervals_to_samples(estimated_intervals, - estimated_labels, - sample_size=frame_size)[-1] + y_est = util.intervals_to_samples( + estimated_intervals, estimated_labels, sample_size=frame_size + )[-1] y_est = util.index_labels(y_est)[0] @@ -514,7 +511,7 @@ def rand_index(reference_intervals, reference_labels, def _contingency_matrix(reference_indices, estimated_indices): - """Computes the contingency matrix of a true labeling vs an estimated one. + """Compute the contingency matrix of a true labeling vs an estimated one. Parameters ---------- @@ -530,17 +527,16 @@ def _contingency_matrix(reference_indices, estimated_indices): .. note:: Based on sklearn.metrics.cluster.contingency_matrix """ - ref_classes, ref_class_idx = np.unique(reference_indices, - return_inverse=True) - est_classes, est_class_idx = np.unique(estimated_indices, - return_inverse=True) + ref_classes, ref_class_idx = np.unique(reference_indices, return_inverse=True) + est_classes, est_class_idx = np.unique(estimated_indices, return_inverse=True) n_ref_classes = ref_classes.shape[0] n_est_classes = est_classes.shape[0] # Using coo_matrix is faster than histogram2d - return scipy.sparse.coo_matrix((np.ones(ref_class_idx.shape[0]), - (ref_class_idx, est_class_idx)), - shape=(n_ref_classes, n_est_classes), - dtype=np.int64).toarray() + return scipy.sparse.coo_matrix( + (np.ones(ref_class_idx.shape[0]), (ref_class_idx, est_class_idx)), + shape=(n_ref_classes, n_est_classes), + dtype=np.int64, + ).toarray() def _adjusted_rand_index(reference_indices, estimated_indices): @@ -557,7 +553,6 @@ def _adjusted_rand_index(reference_indices, estimated_indices): ------- ari : float Adjusted Rand index - .. note:: Based on sklearn.metrics.cluster.adjusted_rand_score """ @@ -567,32 +562,39 @@ def _adjusted_rand_index(reference_indices, estimated_indices): # Special limit cases: no clustering since the data is not split; # or trivial clustering where each document is assigned a unique cluster. # These are perfect matches hence return 1.0. - if (ref_classes.shape[0] == est_classes.shape[0] == 1 or - ref_classes.shape[0] == est_classes.shape[0] == 0 or - (ref_classes.shape[0] == est_classes.shape[0] == - len(reference_indices))): + if ( + ref_classes.shape[0] == est_classes.shape[0] == 1 + or ref_classes.shape[0] == est_classes.shape[0] == 0 + or (ref_classes.shape[0] == est_classes.shape[0] == len(reference_indices)) + ): return 1.0 contingency = _contingency_matrix(reference_indices, estimated_indices) # Compute the ARI using the contingency data - sum_comb_c = sum(scipy.special.comb(n_c, 2, exact=1) for n_c in - contingency.sum(axis=1)) - sum_comb_k = sum(scipy.special.comb(n_k, 2, exact=1) for n_k in - contingency.sum(axis=0)) - - sum_comb = sum((scipy.special.comb(n_ij, 2, exact=1) for n_ij in - contingency.flatten())) - prod_comb = (sum_comb_c * sum_comb_k)/float(scipy.special.comb(n_samples, - 2)) - mean_comb = (sum_comb_k + sum_comb_c)/2. - return (sum_comb - prod_comb)/(mean_comb - prod_comb) - - -def ari(reference_intervals, reference_labels, - estimated_intervals, estimated_labels, - frame_size=0.1): - """Adjusted Rand Index (ARI) for frame clustering segmentation evaluation. + sum_comb_c = sum( + scipy.special.comb(n_c, 2, exact=1) for n_c in contingency.sum(axis=1) + ) + sum_comb_k = sum( + scipy.special.comb(n_k, 2, exact=1) for n_k in contingency.sum(axis=0) + ) + + sum_comb = sum( + (scipy.special.comb(n_ij, 2, exact=1) for n_ij in contingency.flatten()) + ) + prod_comb = (sum_comb_c * sum_comb_k) / float(scipy.special.comb(n_samples, 2)) + mean_comb = (sum_comb_k + sum_comb_c) / 2.0 + return (sum_comb - prod_comb) / (mean_comb - prod_comb) + + +def ari( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, +): + """Compute the Adjusted Rand Index (ARI) for frame clustering segmentation evaluation. Examples -------- @@ -635,25 +637,26 @@ def ari(reference_intervals, reference_labels, Adjusted Rand index between segmentations. """ - validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels) + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals if reference_intervals.size == 0 or estimated_intervals.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Generate the cluster labels - y_ref = util.intervals_to_samples(reference_intervals, - reference_labels, - sample_size=frame_size)[-1] + y_ref = util.intervals_to_samples( + reference_intervals, reference_labels, sample_size=frame_size + )[-1] y_ref = util.index_labels(y_ref)[0] # Map to index space - y_est = util.intervals_to_samples(estimated_intervals, - estimated_labels, - sample_size=frame_size)[-1] + y_est = util.intervals_to_samples( + estimated_intervals, estimated_labels, sample_size=frame_size + )[-1] y_est = util.index_labels(y_est)[0] @@ -677,13 +680,13 @@ def _mutual_info_score(reference_indices, estimated_indices, contingency=None): ------- mi : float Mutual information - .. note:: Based on sklearn.metrics.cluster.mutual_info_score """ if contingency is None: - contingency = _contingency_matrix(reference_indices, - estimated_indices).astype(float) + contingency = _contingency_matrix(reference_indices, estimated_indices).astype( + float + ) contingency_sum = np.sum(contingency) pi = np.sum(contingency, axis=1) pj = np.sum(contingency, axis=0) @@ -696,13 +699,15 @@ def _mutual_info_score(reference_indices, estimated_indices, contingency=None): # log(a / b) should be calculated as log(a) - log(b) for # possible loss of precision log_outer = -np.log(outer[nnz]) + np.log(pi.sum()) + np.log(pj.sum()) - mi = (contingency_nm * (log_contingency_nm - np.log(contingency_sum)) + - contingency_nm * log_outer) + mi = ( + contingency_nm * (log_contingency_nm - np.log(contingency_sum)) + + contingency_nm * log_outer + ) return mi.sum() def _entropy(labels): - """Calculates the entropy for a labeling. + """Calculate the entropy for a labeling. Parameters ---------- @@ -713,7 +718,6 @@ def _entropy(labels): ------- entropy : float Entropy of the labeling. - .. note:: Based on sklearn.metrics.cluster.entropy """ @@ -736,7 +740,6 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): ---------- reference_indices : np.ndarray Array of reference indices - estimated_indices : np.ndarray Array of estimated indices @@ -744,7 +747,6 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): ------- ami : float <= 1.0 Mutual information - .. note:: Based on sklearn.metrics.cluster.adjusted_mutual_info_score and sklearn.metrics.cluster.expected_mutual_info_score @@ -754,14 +756,18 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): est_classes = np.unique(estimated_indices) # Special limit cases: no clustering since the data is not split. # This is a perfect match hence return 1.0. - if (ref_classes.shape[0] == est_classes.shape[0] == 1 or - ref_classes.shape[0] == est_classes.shape[0] == 0): + if ( + ref_classes.shape[0] == est_classes.shape[0] == 1 + or ref_classes.shape[0] == est_classes.shape[0] == 0 + ): return 1.0 - contingency = _contingency_matrix(reference_indices, - estimated_indices).astype(float) + contingency = _contingency_matrix(reference_indices, estimated_indices).astype( + float + ) # Calculate the MI for the two clusterings - mi = _mutual_info_score(reference_indices, estimated_indices, - contingency=contingency) + mi = _mutual_info_score( + reference_indices, estimated_indices, contingency=contingency + ) # The following code is based on # sklearn.metrics.cluster.expected_mutual_information R, C = contingency.shape @@ -771,7 +777,7 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): # There are three major terms to the EMI equation, which are multiplied to # and then summed over varying nij values. # While nijs[0] will never be used, having it simplifies the indexing. - nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float') + nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype="float") # Stops divide by zero warnings. As its not used, no issue. nijs[0] = 1 # term1 is nij / N @@ -790,7 +796,7 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): gln_N = scipy.special.gammaln(N + 1) gln_nij = scipy.special.gammaln(nijs + 1) # start and end values for nij terms for each summation. - start = np.array([[v - N + w for w in b] for v in a], dtype='int') + start = np.array([[v - N + w for w in b] for v in a], dtype="int") start = np.maximum(start, 1) end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1 # emi itself is a summation over the various values. @@ -800,13 +806,19 @@ def _adjusted_mutual_info_score(reference_indices, estimated_indices): for nij in range(start[i, j], end[i, j]): term2 = log_Nnij[nij] - log_ab_outer[i, j] # Numerators are positive, denominators are negative. - gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] - - gln_N - gln_nij[nij] - - scipy.special.gammaln(a[i] - nij + 1) - - scipy.special.gammaln(b[j] - nij + 1) - - scipy.special.gammaln(N - a[i] - b[j] + nij + 1)) + gln = ( + gln_a[i] + + gln_b[j] + + gln_Na[i] + + gln_Nb[j] + - gln_N + - gln_nij[nij] + - scipy.special.gammaln(a[i] - nij + 1) + - scipy.special.gammaln(b[j] - nij + 1) + - scipy.special.gammaln(N - a[i] - b[j] + nij + 1) + ) term3 = np.exp(gln) - emi += (term1[nij] * term2 * term3) + emi += term1[nij] * term2 * term3 # Calculate entropy for each labeling h_true, h_pred = _entropy(reference_indices), _entropy(estimated_indices) ami = (mi - emi) / (max(h_true, h_pred) - emi) @@ -821,7 +833,6 @@ def _normalized_mutual_info_score(reference_indices, estimated_indices): ---------- reference_indices : np.ndarray Array of reference indices - estimated_indices : np.ndarray Array of estimated indices @@ -829,7 +840,6 @@ def _normalized_mutual_info_score(reference_indices, estimated_indices): ------- nmi : float <= 1.0 Normalized mutual information - .. note:: Based on sklearn.metrics.cluster.normalized_mutual_info_score """ @@ -837,15 +847,19 @@ def _normalized_mutual_info_score(reference_indices, estimated_indices): est_classes = np.unique(estimated_indices) # Special limit cases: no clustering since the data is not split. # This is a perfect match hence return 1.0. - if (ref_classes.shape[0] == est_classes.shape[0] == 1 or - ref_classes.shape[0] == est_classes.shape[0] == 0): + if ( + ref_classes.shape[0] == est_classes.shape[0] == 1 + or ref_classes.shape[0] == est_classes.shape[0] == 0 + ): return 1.0 - contingency = _contingency_matrix(reference_indices, - estimated_indices).astype(float) - contingency = np.array(contingency, dtype='float') + contingency = _contingency_matrix(reference_indices, estimated_indices).astype( + float + ) + contingency = np.array(contingency, dtype="float") # Calculate the MI for the two clusterings - mi = _mutual_info_score(reference_indices, estimated_indices, - contingency=contingency) + mi = _mutual_info_score( + reference_indices, estimated_indices, contingency=contingency + ) # Calculate the expected value for the mutual information # Calculate entropy for each labeling h_true, h_pred = _entropy(reference_indices), _entropy(estimated_indices) @@ -853,9 +867,13 @@ def _normalized_mutual_info_score(reference_indices, estimated_indices): return nmi -def mutual_information(reference_intervals, reference_labels, - estimated_intervals, estimated_labels, - frame_size=0.1): +def mutual_information( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, +): """Frame-clustering segmentation: mutual information metrics. Examples @@ -905,25 +923,26 @@ def mutual_information(reference_intervals, reference_labels, Normalize mutual information between segmentations """ - validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels) + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals if reference_intervals.size == 0 or estimated_intervals.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Generate the cluster labels - y_ref = util.intervals_to_samples(reference_intervals, - reference_labels, - sample_size=frame_size)[-1] + y_ref = util.intervals_to_samples( + reference_intervals, reference_labels, sample_size=frame_size + )[-1] y_ref = util.index_labels(y_ref)[0] # Map to index space - y_est = util.intervals_to_samples(estimated_intervals, - estimated_labels, - sample_size=frame_size)[-1] + y_est = util.intervals_to_samples( + estimated_intervals, estimated_labels, sample_size=frame_size + )[-1] y_est = util.index_labels(y_est)[0] @@ -939,8 +958,15 @@ def mutual_information(reference_intervals, reference_labels, return mutual_info, adj_mutual_info, norm_mutual_info -def nce(reference_intervals, reference_labels, estimated_intervals, - estimated_labels, frame_size=0.1, beta=1.0, marginal=False): +def nce( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, + beta=1.0, + marginal=False, +): """Frame-clustering segmentation: normalized conditional entropy Computes cross-entropy of cluster assignment, normalized by the @@ -985,7 +1011,6 @@ def nce(reference_intervals, reference_labels, estimated_intervals, beta : float > 0 beta for F-measure (Default value = 1.0) - marginal : bool If `False`, normalize conditional entropy by uniform entropy. If `True`, normalize conditional entropy by the marginal entropy. @@ -1001,7 +1026,6 @@ def nce(reference_intervals, reference_labels, estimated_intervals, - For `marginal=True`, ``1 - H(y_est | y_ref) / H(y_est)`` If `|y_est|==1`, then `S_over` will be 0. - S_under Under-clustering score: @@ -1010,31 +1034,29 @@ def nce(reference_intervals, reference_labels, estimated_intervals, - For `marginal=True`, ``1 - H(y_ref | y_est) / H(y_ref)`` If `|y_ref|==1`, then `S_under` will be 0. - S_F F-measure for (S_over, S_under) - """ - - validate_structure(reference_intervals, reference_labels, - estimated_intervals, estimated_labels) + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals if reference_intervals.size == 0 or estimated_intervals.size == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 # Generate the cluster labels - y_ref = util.intervals_to_samples(reference_intervals, - reference_labels, - sample_size=frame_size)[-1] + y_ref = util.intervals_to_samples( + reference_intervals, reference_labels, sample_size=frame_size + )[-1] y_ref = util.index_labels(y_ref)[0] # Map to index space - y_est = util.intervals_to_samples(estimated_intervals, - estimated_labels, - sample_size=frame_size)[-1] + y_est = util.intervals_to_samples( + estimated_intervals, estimated_labels, sample_size=frame_size + )[-1] y_est = util.index_labels(y_est)[0] @@ -1065,19 +1087,25 @@ def nce(reference_intervals, reference_labels, estimated_intervals, score_under = 0.0 if z_ref > 0: - score_under = 1. - true_given_est / z_ref + score_under = 1.0 - true_given_est / z_ref score_over = 0.0 if z_est > 0: - score_over = 1. - pred_given_ref / z_est + score_over = 1.0 - pred_given_ref / z_est f_measure = util.f_measure(score_over, score_under, beta=beta) return score_over, score_under, f_measure -def vmeasure(reference_intervals, reference_labels, estimated_intervals, - estimated_labels, frame_size=0.1, beta=1.0): +def vmeasure( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=0.1, + beta=1.0, +): """Frame-clustering segmentation: v-measure Computes cross-entropy of cluster assignment, normalized by the @@ -1132,22 +1160,23 @@ def vmeasure(reference_intervals, reference_labels, estimated_intervals, ``1 - H(y_est | y_ref) / H(y_est)`` If `|y_est|==1`, then `V_precision` will be 0. - V_recall Under-clustering score: ``1 - H(y_ref | y_est) / H(y_ref)`` If `|y_ref|==1`, then `V_recall` will be 0. - V_F F-measure for (V_precision, V_recall) - """ - - return nce(reference_intervals, reference_labels, - estimated_intervals, estimated_labels, - frame_size=frame_size, beta=beta, - marginal=True) + return nce( + reference_intervals, + reference_labels, + estimated_intervals, + estimated_labels, + frame_size=frame_size, + beta=beta, + marginal=True, + ) def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): @@ -1176,7 +1205,7 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): est_labels : list, shape=(m,) estimated segment labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -1185,68 +1214,84 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): scores : dict Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. - """ - # Adjust timespan of estimations relative to ground truth - ref_intervals, ref_labels = \ - util.adjust_intervals(ref_intervals, labels=ref_labels, t_min=0.0) + ref_intervals, ref_labels = util.adjust_intervals( + ref_intervals, labels=ref_labels, t_min=0.0 + ) - est_intervals, est_labels = \ - util.adjust_intervals(est_intervals, labels=est_labels, t_min=0.0, - t_max=ref_intervals.max()) + est_intervals, est_labels = util.adjust_intervals( + est_intervals, labels=est_labels, t_min=0.0, t_max=ref_intervals.max() + ) # Now compute all the metrics scores = collections.OrderedDict() # Boundary detection # Force these values for window - kwargs['window'] = .5 - scores['Precision@0.5'], scores['Recall@0.5'], scores['F-measure@0.5'] = \ - util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs) - - kwargs['window'] = 3.0 - scores['Precision@3.0'], scores['Recall@3.0'], scores['F-measure@3.0'] = \ - util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs) + kwargs["window"] = 0.5 + ( + scores["Precision@0.5"], + scores["Recall@0.5"], + scores["F-measure@0.5"], + ) = util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs) + + kwargs["window"] = 3.0 + ( + scores["Precision@3.0"], + scores["Recall@3.0"], + scores["F-measure@3.0"], + ) = util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs) # Boundary deviation - scores['Ref-to-est deviation'], scores['Est-to-ref deviation'] = \ - util.filter_kwargs(deviation, ref_intervals, est_intervals, **kwargs) + scores["Ref-to-est deviation"], scores["Est-to-ref deviation"] = util.filter_kwargs( + deviation, ref_intervals, est_intervals, **kwargs + ) # Pairwise clustering - (scores['Pairwise Precision'], - scores['Pairwise Recall'], - scores['Pairwise F-measure']) = util.filter_kwargs(pairwise, - ref_intervals, - ref_labels, - est_intervals, - est_labels, **kwargs) + ( + scores["Pairwise Precision"], + scores["Pairwise Recall"], + scores["Pairwise F-measure"], + ) = util.filter_kwargs( + pairwise, ref_intervals, ref_labels, est_intervals, est_labels, **kwargs + ) # Rand index - scores['Rand Index'] = util.filter_kwargs(rand_index, ref_intervals, - ref_labels, est_intervals, - est_labels, **kwargs) + scores["Rand Index"] = util.filter_kwargs( + rand_index, ref_intervals, ref_labels, est_intervals, est_labels, **kwargs + ) # Adjusted rand index - scores['Adjusted Rand Index'] = util.filter_kwargs(ari, ref_intervals, - ref_labels, - est_intervals, - est_labels, **kwargs) + scores["Adjusted Rand Index"] = util.filter_kwargs( + ari, ref_intervals, ref_labels, est_intervals, est_labels, **kwargs + ) # Mutual information metrics - (scores['Mutual Information'], - scores['Adjusted Mutual Information'], - scores['Normalized Mutual Information']) = \ - util.filter_kwargs(mutual_information, ref_intervals, ref_labels, - est_intervals, est_labels, **kwargs) + ( + scores["Mutual Information"], + scores["Adjusted Mutual Information"], + scores["Normalized Mutual Information"], + ) = util.filter_kwargs( + mutual_information, + ref_intervals, + ref_labels, + est_intervals, + est_labels, + **kwargs + ) # Conditional entropy metrics - scores['NCE Over'], scores['NCE Under'], scores['NCE F-measure'] = \ - util.filter_kwargs(nce, ref_intervals, ref_labels, est_intervals, - est_labels, **kwargs) + ( + scores["NCE Over"], + scores["NCE Under"], + scores["NCE F-measure"], + ) = util.filter_kwargs( + nce, ref_intervals, ref_labels, est_intervals, est_labels, **kwargs + ) # V-measure metrics - scores['V Precision'], scores['V Recall'], scores['V-measure'] = \ - util.filter_kwargs(vmeasure, ref_intervals, ref_labels, est_intervals, - est_labels, **kwargs) + scores["V Precision"], scores["V Recall"], scores["V-measure"] = util.filter_kwargs( + vmeasure, ref_intervals, ref_labels, est_intervals, est_labels, **kwargs + ) return scores diff --git a/mir_eval/separation.py b/mir_eval/separation.py index 474b33ba..0bb0704e 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -''' +""" Source separation algorithms attempt to extract recordings of individual sources from a recording of a mixture of sources. Evaluation methods for source separation compare the extracted sources from reference sources and @@ -43,7 +43,7 @@ Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006. -''' +""" import numpy as np import scipy.fftpack @@ -60,7 +60,7 @@ def validate(reference_sources, estimated_sources): - """Checks that the input data to a metric are valid, and throws helpful + """Check that the input data to a metric are valid, and throws helpful errors if not. Parameters @@ -71,64 +71,77 @@ def validate(reference_sources, estimated_sources): matrix containing estimated sources """ - if reference_sources.shape != estimated_sources.shape: - raise ValueError('The shape of estimated sources and the true ' - 'sources should match. reference_sources.shape ' - '= {}, estimated_sources.shape ' - '= {}'.format(reference_sources.shape, - estimated_sources.shape)) + raise ValueError( + "The shape of estimated sources and the true " + "sources should match. reference_sources.shape " + "= {}, estimated_sources.shape " + "= {}".format(reference_sources.shape, estimated_sources.shape) + ) if reference_sources.ndim > 3 or estimated_sources.ndim > 3: - raise ValueError('The number of dimensions is too high (must be less ' - 'than 3). reference_sources.ndim = {}, ' - 'estimated_sources.ndim ' - '= {}'.format(reference_sources.ndim, - estimated_sources.ndim)) + raise ValueError( + "The number of dimensions is too high (must be less " + "than 3). reference_sources.ndim = {}, " + "estimated_sources.ndim " + "= {}".format(reference_sources.ndim, estimated_sources.ndim) + ) if reference_sources.size == 0: - warnings.warn("reference_sources is empty, should be of size " - "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays") + warnings.warn( + "reference_sources is empty, should be of size " + "(nsrc, nsample). sdr, sir, sar, and perm will all " + "be empty np.ndarrays" + ) elif _any_source_silent(reference_sources): - raise ValueError('All the reference sources should be non-silent (not ' - 'all-zeros), but at least one of the reference ' - 'sources is all 0s, which introduces ambiguity to the' - ' evaluation. (Otherwise we can add infinitely many ' - 'all-zero sources.)') + raise ValueError( + "All the reference sources should be non-silent (not " + "all-zeros), but at least one of the reference " + "sources is all 0s, which introduces ambiguity to the" + " evaluation. (Otherwise we can add infinitely many " + "all-zero sources.)" + ) if estimated_sources.size == 0: - warnings.warn("estimated_sources is empty, should be of size " - "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays") + warnings.warn( + "estimated_sources is empty, should be of size " + "(nsrc, nsample). sdr, sir, sar, and perm will all " + "be empty np.ndarrays" + ) elif _any_source_silent(estimated_sources): - raise ValueError('All the estimated sources should be non-silent (not ' - 'all-zeros), but at least one of the estimated ' - 'sources is all 0s. Since we require each reference ' - 'source to be non-silent, having a silent estimated ' - 'source will result in an underdetermined system.') - - if (estimated_sources.shape[0] > MAX_SOURCES or - reference_sources.shape[0] > MAX_SOURCES): - raise ValueError('The supplied matrices should be of shape (nsrc,' - ' nsampl) but reference_sources.shape[0] = {} and ' - 'estimated_sources.shape[0] = {} which is greater ' - 'than mir_eval.separation.MAX_SOURCES = {}. To ' - 'override this check, set ' - 'mir_eval.separation.MAX_SOURCES to a ' - 'larger value.'.format(reference_sources.shape[0], - estimated_sources.shape[0], - MAX_SOURCES)) + raise ValueError( + "All the estimated sources should be non-silent (not " + "all-zeros), but at least one of the estimated " + "sources is all 0s. Since we require each reference " + "source to be non-silent, having a silent estimated " + "source will result in an underdetermined system." + ) + + if ( + estimated_sources.shape[0] > MAX_SOURCES + or reference_sources.shape[0] > MAX_SOURCES + ): + raise ValueError( + "The supplied matrices should be of shape (nsrc," + " nsampl) but reference_sources.shape[0] = {} and " + "estimated_sources.shape[0] = {} which is greater " + "than mir_eval.separation.MAX_SOURCES = {}. To " + "override this check, set " + "mir_eval.separation.MAX_SOURCES to a " + "larger value.".format( + reference_sources.shape[0], estimated_sources.shape[0], MAX_SOURCES + ) + ) def _any_source_silent(sources): - """Returns true if the parameter sources has any silent first dimensions""" - return np.any(np.all(np.sum( - sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)) + """Return true if the parameter sources has any silent first dimensions""" + return np.any( + np.all(np.sum(sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1) + ) -def bss_eval_sources(reference_sources, estimated_sources, - compute_permutation=True): +def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=True): """ Ordering and measurement of the separation quality for estimated source signals in terms of filtered true source, interference and artifacts. @@ -184,7 +197,6 @@ def bss_eval_sources(reference_sources, estimated_sources, 92, pp. 1928-1936, 2012. """ - # make sure the input is of shape (nsrc, nsampl) if estimated_sources.ndim == 1: estimated_sources = estimated_sources[np.newaxis, :] @@ -206,18 +218,18 @@ def bss_eval_sources(reference_sources, estimated_sources, sar = np.empty((nsrc, nsrc)) for jest in range(nsrc): for jtrue in range(nsrc): - s_true, e_spat, e_interf, e_artif = \ - _bss_decomp_mtifilt(reference_sources, - estimated_sources[jest], - jtrue, 512) - sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \ - _bss_source_crit(s_true, e_spat, e_interf, e_artif) + s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt( + reference_sources, estimated_sources[jest], jtrue, 512 + ) + sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = _bss_source_crit( + s_true, e_spat, e_interf, e_artif + ) # select the best ordering perms = list(itertools.permutations(list(range(nsrc)))) mean_sir = np.empty(len(perms)) dum = np.arange(nsrc) - for (i, perm) in enumerate(perms): + for i, perm in enumerate(perms): mean_sir[i] = np.mean(sir[perm, dum]) popt = perms[np.argmax(mean_sir)] idx = (popt, dum) @@ -229,21 +241,23 @@ def bss_eval_sources(reference_sources, estimated_sources, sir = np.empty(nsrc) sar = np.empty(nsrc) for j in range(nsrc): - s_true, e_spat, e_interf, e_artif = \ - _bss_decomp_mtifilt(reference_sources, - estimated_sources[j], - j, 512) - sdr[j], sir[j], sar[j] = \ - _bss_source_crit(s_true, e_spat, e_interf, e_artif) + s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt( + reference_sources, estimated_sources[j], j, 512 + ) + sdr[j], sir[j], sar[j] = _bss_source_crit(s_true, e_spat, e_interf, e_artif) # return the default permutation for compatibility popt = np.arange(nsrc) return (sdr, sir, sar, popt) -def bss_eval_sources_framewise(reference_sources, estimated_sources, - window=30*44100, hop=15*44100, - compute_permutation=False): +def bss_eval_sources_framewise( + reference_sources, + estimated_sources, + window=30 * 44100, + hop=15 * 44100, + compute_permutation=False, +): """Framewise computation of bss_eval_sources Please be aware that this function does not compute permutations (by @@ -303,9 +317,7 @@ def bss_eval_sources_framewise(reference_sources, estimated_sources, the mean SIR sense (estimated source number ``perm[j]`` corresponds to true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for all windows if ``compute_permutation`` is ``False`` - """ - # make sure the input is of shape (nsrc, nsampl) if estimated_sources.ndim == 1: estimated_sources = estimated_sources[np.newaxis, :] @@ -319,14 +331,12 @@ def bss_eval_sources_framewise(reference_sources, estimated_sources, nsrc = reference_sources.shape[0] - nwin = int( - np.floor((reference_sources.shape[1] - window + hop) / hop) - ) + nwin = int(np.floor((reference_sources.shape[1] - window + hop) / hop)) # if fewer than 2 windows would be evaluated, return the sources result if nwin < 2: - result = bss_eval_sources(reference_sources, - estimated_sources, - compute_permutation) + result = bss_eval_sources( + reference_sources, estimated_sources, compute_permutation + ) return [np.expand_dims(score, -1) for score in result] # compute the criteria across all windows @@ -341,8 +351,7 @@ def bss_eval_sources_framewise(reference_sources, estimated_sources, ref_slice = reference_sources[:, win_slice] est_slice = estimated_sources[:, win_slice] # check for a silent frame - if (not _any_source_silent(ref_slice) and - not _any_source_silent(est_slice)): + if not _any_source_silent(ref_slice) and not _any_source_silent(est_slice): sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources( ref_slice, est_slice, compute_permutation ) @@ -353,9 +362,8 @@ def bss_eval_sources_framewise(reference_sources, estimated_sources, return sdr, sir, sar, perm -def bss_eval_images(reference_sources, estimated_sources, - compute_permutation=True): - """Implementation of the bss_eval_images function from the +def bss_eval_images(reference_sources, estimated_sources, compute_permutation=True): + """Compute the bss_eval_images function from the BSS_EVAL Matlab toolbox. Ordering and measurement of the separation quality for estimated source @@ -411,9 +419,7 @@ def bss_eval_images(reference_sources, estimated_sources, Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign (2007-2010): Achievements and remaining challenges", Signal Processing, 92, pp. 1928-1936, 2012. - """ - # make sure the input has 3 dimensions # assuming input is in shape (nsampl) or (nsrc, nsampl) estimated_sources = np.atleast_3d(estimated_sources) @@ -423,8 +429,7 @@ def bss_eval_images(reference_sources, estimated_sources, validate(reference_sources, estimated_sources) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: - return np.array([]), np.array([]), np.array([]), \ - np.array([]), np.array([]) + return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]) # determine size parameters nsrc = estimated_sources.shape[0] @@ -440,26 +445,24 @@ def bss_eval_images(reference_sources, estimated_sources, sar = np.empty((nsrc, nsrc)) for jest in range(nsrc): for jtrue in range(nsrc): - s_true, e_spat, e_interf, e_artif = \ - _bss_decomp_mtifilt_images( - reference_sources, - np.reshape( - estimated_sources[jest], - (nsampl, nchan), - order='F' - ), - jtrue, - 512 - ) - sdr[jest, jtrue], isr[jest, jtrue], \ - sir[jest, jtrue], sar[jest, jtrue] = \ - _bss_image_crit(s_true, e_spat, e_interf, e_artif) + s_true, e_spat, e_interf, e_artif = _bss_decomp_mtifilt_images( + reference_sources, + np.reshape(estimated_sources[jest], (nsampl, nchan), order="F"), + jtrue, + 512, + ) + ( + sdr[jest, jtrue], + isr[jest, jtrue], + sir[jest, jtrue], + sar[jest, jtrue], + ) = _bss_image_crit(s_true, e_spat, e_interf, e_artif) # select the best ordering perms = list(itertools.permutations(list(range(nsrc)))) mean_sir = np.empty(len(perms)) dum = np.arange(nsrc) - for (i, perm) in enumerate(perms): + for i, perm in enumerate(perms): mean_sir[i] = np.mean(sir[perm, dum]) popt = perms[np.argmax(mean_sir)] idx = (popt, dum) @@ -471,28 +474,35 @@ def bss_eval_images(reference_sources, estimated_sources, isr = np.empty(nsrc) sir = np.empty(nsrc) sar = np.empty(nsrc) - Gj = [0] * nsrc # prepare G matrics with zeroes + Gj = [0] * nsrc # prepare G matrices with zeroes G = np.zeros(1) for j in range(nsrc): # save G matrix to avoid recomputing it every call - s_true, e_spat, e_interf, e_artif, Gj_temp, G = \ - _bss_decomp_mtifilt_images(reference_sources, - np.reshape(estimated_sources[j], - (nsampl, nchan), - order='F'), - j, 512, Gj[j], G) + s_true, e_spat, e_interf, e_artif, Gj_temp, G = _bss_decomp_mtifilt_images( + reference_sources, + np.reshape(estimated_sources[j], (nsampl, nchan), order="F"), + j, + 512, + Gj[j], + G, + ) Gj[j] = Gj_temp - sdr[j], isr[j], sir[j], sar[j] = \ - _bss_image_crit(s_true, e_spat, e_interf, e_artif) + sdr[j], isr[j], sir[j], sar[j] = _bss_image_crit( + s_true, e_spat, e_interf, e_artif + ) # return the default permutation for compatibility popt = np.arange(nsrc) return (sdr, isr, sir, sar, popt) -def bss_eval_images_framewise(reference_sources, estimated_sources, - window=30*44100, hop=15*44100, - compute_permutation=False): +def bss_eval_images_framewise( + reference_sources, + estimated_sources, + window=30 * 44100, + hop=15 * 44100, + compute_permutation=False, +): """Framewise computation of bss_eval_images Please be aware that this function does not compute permutations (by @@ -554,9 +564,7 @@ def bss_eval_images_framewise(reference_sources, estimated_sources, true source number j) Note: perm will be range(nsrc) for all windows if compute_permutation is False - """ - # make sure the input has 3 dimensions # assuming input is in shape (nsampl) or (nsrc, nsampl) estimated_sources = np.atleast_3d(estimated_sources) @@ -570,14 +578,12 @@ def bss_eval_images_framewise(reference_sources, estimated_sources, nsrc = reference_sources.shape[0] - nwin = int( - np.floor((reference_sources.shape[1] - window + hop) / hop) - ) + nwin = int(np.floor((reference_sources.shape[1] - window + hop) / hop)) # if fewer than 2 windows would be evaluated, return the images result if nwin < 2: - result = bss_eval_images(reference_sources, - estimated_sources, - compute_permutation) + result = bss_eval_images( + reference_sources, estimated_sources, compute_permutation + ) return [np.expand_dims(score, -1) for score in result] # compute the criteria across all windows @@ -593,12 +599,10 @@ def bss_eval_images_framewise(reference_sources, estimated_sources, ref_slice = reference_sources[:, win_slice, :] est_slice = estimated_sources[:, win_slice, :] # check for a silent frame - if (not _any_source_silent(ref_slice) and - not _any_source_silent(est_slice)): - sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \ - bss_eval_images( - ref_slice, est_slice, compute_permutation - ) + if not _any_source_silent(ref_slice) and not _any_source_silent(est_slice): + sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_images( + ref_slice, est_slice, compute_permutation + ) else: # if we have a silent frame set results as np.nan sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan @@ -617,19 +621,20 @@ def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen): # true source image s_true = np.hstack((reference_sources[j], np.zeros(flen - 1))) # spatial (or filtering) distortion - e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source, - flen) - s_true + e_spat = ( + _project(reference_sources[j, np.newaxis, :], estimated_source, flen) - s_true + ) # interference - e_interf = _project(reference_sources, - estimated_source, flen) - s_true - e_spat + e_interf = _project(reference_sources, estimated_source, flen) - s_true - e_spat # artifacts e_artif = -s_true - e_spat - e_interf e_artif[:nsampl] += estimated_source return (s_true, e_spat, e_interf, e_artif) -def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, - Gj=None, G=None): +def _bss_decomp_mtifilt_images( + reference_sources, estimated_source, j, flen, Gj=None, G=None +): """Decomposition of an estimated source image into four components representing respectively the true source image, spatial (or filtering) distortion, interference and artifacts, derived from the true source @@ -638,7 +643,7 @@ def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, Improved performance can be gained by passing Gj and G parameters initially as all zeros. These parameters store the results from the computation of the G matrix in _project_images and then return them for subsequent calls - to this function. This only works when not computing permuations. + to this function. This only works when not computing permutations. """ nsampl = np.shape(estimated_source)[0] nchan = np.shape(estimated_source)[1] @@ -646,25 +651,27 @@ def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, saveg = Gj is not None and G is not None # decomposition # true source image - s_true = np.hstack((np.reshape(reference_sources[j], - (nsampl, nchan), - order="F").transpose(), - np.zeros((nchan, flen - 1)))) + s_true = np.hstack( + ( + np.reshape(reference_sources[j], (nsampl, nchan), order="F").transpose(), + np.zeros((nchan, flen - 1)), + ) + ) # spatial (or filtering) distortion if saveg: - e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :], - estimated_source, flen, Gj) + e_spat, Gj = _project_images( + reference_sources[j, np.newaxis, :], estimated_source, flen, Gj + ) else: - e_spat = _project_images(reference_sources[j, np.newaxis, :], - estimated_source, flen) + e_spat = _project_images( + reference_sources[j, np.newaxis, :], estimated_source, flen + ) e_spat = e_spat - s_true # interference if saveg: - e_interf, G = _project_images(reference_sources, - estimated_source, flen, G) + e_interf, G = _project_images(reference_sources, estimated_source, flen, G) else: - e_interf = _project_images(reference_sources, - estimated_source, flen) + e_interf = _project_images(reference_sources, estimated_source, flen) e_interf = e_interf - s_true - e_spat # artifacts e_artif = -s_true - e_spat - e_interf @@ -685,10 +692,9 @@ def _project(reference_sources, estimated_source, flen): # computing coefficients of least squares problem via FFT ## # zero padding and FFT of input data - reference_sources = np.hstack((reference_sources, - np.zeros((nsrc, flen - 1)))) + reference_sources = np.hstack((reference_sources, np.zeros((nsrc, flen - 1)))) estimated_source = np.hstack((estimated_source, np.zeros(flen - 1))) - n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) + n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.0))) sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) sef = scipy.fftpack.fft(estimated_source, n=n_fft) # inner products between delayed versions of reference_sources @@ -697,28 +703,27 @@ def _project(reference_sources, estimated_source, flen): for j in range(nsrc): ssf = sf[i] * np.conj(sf[j]) ssf = np.real(scipy.fftpack.ifft(ssf)) - ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), - r=ssf[:flen]) - G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss - G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T + ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen]) + G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss + G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T # inner products between estimated_source and delayed versions of # reference_sources D = np.zeros(nsrc * flen) for i in range(nsrc): ssef = sf[i] * np.conj(sef) ssef = np.real(scipy.fftpack.ifft(ssef)) - D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) + D[i * flen : (i + 1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) # Computing projection # Distortion filters try: - C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F') + C = np.linalg.solve(G, D).reshape(flen, nsrc, order="F") except np.linalg.linalg.LinAlgError: - C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F') + C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order="F") # Filtering sproj = np.zeros(nsampl + flen - 1) for i in range(nsrc): - sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1] + sproj += fftconvolve(C[:, i], reference_sources[i])[: nsampl + flen - 1] return sproj @@ -732,16 +737,19 @@ def _project_images(reference_sources, estimated_source, flen, G=None): nsrc = reference_sources.shape[0] nsampl = reference_sources.shape[1] nchan = reference_sources.shape[2] - reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)), - (nchan*nsrc, nsampl), order='F') + reference_sources = np.reshape( + np.transpose(reference_sources, (2, 0, 1)), (nchan * nsrc, nsampl), order="F" + ) # computing coefficients of least squares problem via FFT ## # zero padding and FFT of input data - reference_sources = np.hstack((reference_sources, - np.zeros((nchan*nsrc, flen - 1)))) - estimated_source = \ - np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1)))) - n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) + reference_sources = np.hstack( + (reference_sources, np.zeros((nchan * nsrc, flen - 1))) + ) + estimated_source = np.hstack( + (estimated_source.transpose(), np.zeros((nchan, flen - 1))) + ) + n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.0))) sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) sef = scipy.fftpack.fft(estimated_source, n=n_fft) @@ -750,25 +758,23 @@ def _project_images(reference_sources, estimated_source, flen, G=None): saveg = False G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) for i in range(nchan * nsrc): - for j in range(i+1): + for j in range(i + 1): ssf = sf[i] * np.conj(sf[j]) ssf = np.real(scipy.fftpack.ifft(ssf)) - ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), - r=ssf[:flen]) - G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss - G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T + ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen]) + G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss + G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T else: # avoid recomputing G (only works if no permutation is desired) saveg = True # return G if np.all(G == 0): # only compute G if passed as 0 G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) for i in range(nchan * nsrc): - for j in range(i+1): + for j in range(i + 1): ssf = sf[i] * np.conj(sf[j]) ssf = np.real(scipy.fftpack.ifft(ssf)) - ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), - r=ssf[:flen]) - G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss - G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T + ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen]) + G[i * flen : (i + 1) * flen, j * flen : (j + 1) * flen] = ss + G[j * flen : (j + 1) * flen, i * flen : (i + 1) * flen] = ss.T # inner products between estimated_source and delayed versions of # reference_sources @@ -777,22 +783,23 @@ def _project_images(reference_sources, estimated_source, flen, G=None): for i in range(nchan): ssef = sf[k] * np.conj(sef[i]) ssef = np.real(scipy.fftpack.ifft(ssef)) - D[k * flen: (k+1) * flen, i] = \ - np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose() + D[k * flen : (k + 1) * flen, i] = np.hstack( + (ssef[0], ssef[-1:-flen:-1]) + ).transpose() # Computing projection # Distortion filters try: - C = np.linalg.solve(G, D).reshape(flen, nchan*nsrc, nchan, order='F') + C = np.linalg.solve(G, D).reshape(flen, nchan * nsrc, nchan, order="F") except np.linalg.linalg.LinAlgError: - C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan*nsrc, nchan, - order='F') + C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan * nsrc, nchan, order="F") # Filtering sproj = np.zeros((nchan, nsampl + flen - 1)) for k in range(nchan * nsrc): for i in range(nchan): - sproj[i] += fftconvolve(C[:, k, i].transpose(), - reference_sources[k])[:nsampl + flen - 1] + sproj[i] += fftconvolve(C[:, k, i].transpose(), reference_sources[k])[ + : nsampl + flen - 1 + ] # return G only if it was passed in if saveg: return sproj, G @@ -806,9 +813,9 @@ def _bss_source_crit(s_true, e_spat, e_interf, e_artif): """ # energy ratios s_filt = s_true + e_spat - sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2)) + sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif) ** 2)) sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2)) - sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2)) + sar = _safe_db(np.sum((s_filt + e_interf) ** 2), np.sum(e_artif**2)) return (sdr, sir, sar) @@ -817,10 +824,10 @@ def _bss_image_crit(s_true, e_spat, e_interf, e_artif): filtered true source, spatial error, interference and artifacts. """ # energy ratios - sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat+e_interf+e_artif)**2)) + sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat + e_interf + e_artif) ** 2)) isr = _safe_db(np.sum(s_true**2), np.sum(e_spat**2)) - sir = _safe_db(np.sum((s_true+e_spat)**2), np.sum(e_interf**2)) - sar = _safe_db(np.sum((s_true+e_spat+e_interf)**2), np.sum(e_artif**2)) + sir = _safe_db(np.sum((s_true + e_spat) ** 2), np.sum(e_interf**2)) + sar = _safe_db(np.sum((s_true + e_spat + e_interf) ** 2), np.sum(e_artif**2)) return (sdr, isr, sir, sar) @@ -856,7 +863,7 @@ def evaluate(reference_sources, estimated_sources, **kwargs): matrix containing true sources estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) matrix containing estimated sources - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -871,51 +878,39 @@ def evaluate(reference_sources, estimated_sources, **kwargs): scores = collections.OrderedDict() sdr, isr, sir, sar, perm = util.filter_kwargs( - bss_eval_images, - reference_sources, - estimated_sources, - **kwargs + bss_eval_images, reference_sources, estimated_sources, **kwargs ) - scores['Images - Source to Distortion'] = sdr.tolist() - scores['Images - Image to Spatial'] = isr.tolist() - scores['Images - Source to Interference'] = sir.tolist() - scores['Images - Source to Artifact'] = sar.tolist() - scores['Images - Source permutation'] = perm.tolist() + scores["Images - Source to Distortion"] = sdr.tolist() + scores["Images - Image to Spatial"] = isr.tolist() + scores["Images - Source to Interference"] = sir.tolist() + scores["Images - Source to Artifact"] = sar.tolist() + scores["Images - Source permutation"] = perm.tolist() sdr, isr, sir, sar, perm = util.filter_kwargs( - bss_eval_images_framewise, - reference_sources, - estimated_sources, - **kwargs + bss_eval_images_framewise, reference_sources, estimated_sources, **kwargs ) - scores['Images Frames - Source to Distortion'] = sdr.tolist() - scores['Images Frames - Image to Spatial'] = isr.tolist() - scores['Images Frames - Source to Interference'] = sir.tolist() - scores['Images Frames - Source to Artifact'] = sar.tolist() - scores['Images Frames - Source permutation'] = perm.tolist() + scores["Images Frames - Source to Distortion"] = sdr.tolist() + scores["Images Frames - Image to Spatial"] = isr.tolist() + scores["Images Frames - Source to Interference"] = sir.tolist() + scores["Images Frames - Source to Artifact"] = sar.tolist() + scores["Images Frames - Source permutation"] = perm.tolist() # Verify we can compute sources on this input if reference_sources.ndim < 3 and estimated_sources.ndim < 3: sdr, sir, sar, perm = util.filter_kwargs( - bss_eval_sources_framewise, - reference_sources, - estimated_sources, - **kwargs + bss_eval_sources_framewise, reference_sources, estimated_sources, **kwargs ) - scores['Sources Frames - Source to Distortion'] = sdr.tolist() - scores['Sources Frames - Source to Interference'] = sir.tolist() - scores['Sources Frames - Source to Artifact'] = sar.tolist() - scores['Sources Frames - Source permutation'] = perm.tolist() + scores["Sources Frames - Source to Distortion"] = sdr.tolist() + scores["Sources Frames - Source to Interference"] = sir.tolist() + scores["Sources Frames - Source to Artifact"] = sar.tolist() + scores["Sources Frames - Source permutation"] = perm.tolist() sdr, sir, sar, perm = util.filter_kwargs( - bss_eval_sources, - reference_sources, - estimated_sources, - **kwargs + bss_eval_sources, reference_sources, estimated_sources, **kwargs ) - scores['Sources - Source to Distortion'] = sdr.tolist() - scores['Sources - Source to Interference'] = sir.tolist() - scores['Sources - Source to Artifact'] = sar.tolist() - scores['Sources - Source permutation'] = perm.tolist() + scores["Sources - Source to Distortion"] = sdr.tolist() + scores["Sources - Source to Interference"] = sir.tolist() + scores["Sources - Source to Artifact"] = sar.tolist() + scores["Sources - Source permutation"] = perm.tolist() return scores diff --git a/mir_eval/sonify.py b/mir_eval/sonify.py index 13a2ea3a..c3b3cdd3 100644 --- a/mir_eval/sonify.py +++ b/mir_eval/sonify.py @@ -1,7 +1,7 @@ -''' +""" Methods which sonify annotations for "evaluation by ear". All functions return a raw signal at the specified sampling rate. -''' +""" import numpy as np from numpy.lib.stride_tricks import as_strided @@ -12,7 +12,7 @@ def clicks(times, fs, click=None, length=None): - """Returns a signal with the signal 'click' placed at each specified time + """Return a signal with the signal 'click' placed at each specified time Parameters ---------- @@ -35,34 +35,35 @@ def clicks(times, fs, click=None, length=None): # Create default click signal if click is None: # 1 kHz tone, 100ms - click = np.sin(2*np.pi*np.arange(fs*.1)*1000/(1.*fs)) + click = np.sin(2 * np.pi * np.arange(fs * 0.1) * 1000 / (1.0 * fs)) # Exponential decay - click *= np.exp(-np.arange(fs*.1)/(fs*.01)) + click *= np.exp(-np.arange(fs * 0.1) / (fs * 0.01)) # Set default length if length is None: - length = int(times.max()*fs + click.shape[0] + 1) + length = int(times.max() * fs + click.shape[0] + 1) # Pre-allocate click signal click_signal = np.zeros(length) # Place clicks for time in times: # Compute the boundaries of the click - start = int(time*fs) + start = int(time * fs) end = start + click.shape[0] # Make sure we don't try to output past the end of the signal if start >= length: break if end >= length: - click_signal[start:] = click[:length - start] + click_signal[start:] = click[: length - start] break # Normally, just add a click here click_signal[start:end] = click return click_signal -def time_frequency(gram, frequencies, times, fs, function=np.sin, length=None, - n_dec=1, threshold=0.01): - """Reverse synthesis of a time-frequency representation of a signal +def time_frequency( + gram, frequencies, times, fs, function=np.sin, length=None, n_dec=1, threshold=0.01 +): + r"""Reverse synthesis of a time-frequency representation of a signal Parameters ---------- @@ -75,19 +76,25 @@ def time_frequency(gram, frequencies, times, fs, function=np.sin, length=None, frequencies : np.ndarray array of size ``gram.shape[0]`` denoting the frequency (in Hz) of each row of gram + times : np.ndarray, shape= ``(gram.shape[1],)`` or ``(gram.shape[1], 2)`` Either the start time (in seconds) of each column in the gram, or the time interval (in seconds) corresponding to each column. + fs : int desired sampling rate of the output signal + function : function function to use to synthesize notes, should be :math:`2\pi`-periodic + length : int desired number of samples in the output signal, defaults to ``times[-1]*fs`` + n_dec : int the number of decimals used to approximate each sonfied frequency. Defaults to 1 decimal place. Higher precision will be slower. + threshold : float optimizes synthesis to only occur for frequencies that have a linear magnitude of at least one element in gram above the given threshold. @@ -119,9 +126,9 @@ def time_frequency(gram, frequencies, times, fs, function=np.sin, length=None, sample_intervals = np.round(times * fs).astype(int) def _fast_synthesize(frequency): - """A faster way to synthesize a signal. - Generate one cycle, and simulate arbitrary repetitions - using array indexing tricks. + """Efficiently synthesize a signal. + Generate one cycle, and simulate arbitrary repetitions + using array indexing tricks. """ # hack so that we can ensure an integer number of periods and samples # rounds frequency to 1st decimal, s.t. 10 * frequency will be an int @@ -134,26 +141,29 @@ def _fast_synthesize(frequency): # is an integer n_samples = int(10.0**n_dec * fs) - short_signal = function(2.0 * np.pi * np.arange(n_samples) * - frequency / fs) + short_signal = function(2.0 * np.pi * np.arange(n_samples) * frequency / fs) # Calculate the number of loops we need to fill the duration - n_repeats = int(np.ceil(length/float(short_signal.shape[0]))) + n_repeats = int(np.ceil(length / float(short_signal.shape[0]))) # Simulate tiling the short buffer by using stride tricks - long_signal = as_strided(short_signal, - shape=(n_repeats, len(short_signal)), - strides=(0, short_signal.itemsize)) + long_signal = as_strided( + short_signal, + shape=(n_repeats, len(short_signal)), + strides=(0, short_signal.itemsize), + ) # Use a flatiter to simulate a long 1D buffer return long_signal.flat def _const_interpolator(value): """Return a function that returns `value` - no matter the input. + no matter the input. """ + def __interpolator(x): return value + return __interpolator # Threshold the tfgram to remove non-positive values @@ -165,7 +175,7 @@ def __interpolator(x): # Check if there is at least one element on each frequency that has a value above the threshold # to justify processing, for optimisation. - spectral_max_magnitudes = np.max(gram, axis = 1) + spectral_max_magnitudes = np.max(gram, axis=1) for n, frequency in enumerate(frequencies): if spectral_max_magnitudes[n] < threshold: continue @@ -178,17 +188,25 @@ def __interpolator(x): # (len, 1) to (len-1, 2), and hence differ from the length of gram (i.e one less), # so we ensure gram is reduced appropriately. gram_interpolator = interp1d( - time_centers, gram[n, :n_times], - kind='linear', bounds_error=False, - fill_value=(gram[n, 0], gram[n, -1])) + time_centers, + gram[n, :n_times], + kind="linear", + bounds_error=False, + fill_value=(gram[n, 0], gram[n, -1]), + ) # If only one time point, create constant interpolator else: gram_interpolator = _const_interpolator(gram[n, 0]) # Create the time-varying scaling for the entire time interval by the piano roll # magnitude and add to the accumulating waveform. - output += wave[:length] * gram_interpolator(np.arange(max(sample_intervals[0][0], 0), - min(sample_intervals[-1][-1], length))) + # FIXME: this logic is broken when length + # does not match the final sample interval + output += wave[:length] * gram_interpolator( + np.arange( + max(sample_intervals[0][0], 0), min(sample_intervals[-1][-1], length) + ) + ) # Normalize, but only if there's non-zero values norm = np.abs(output).max() @@ -198,33 +216,28 @@ def __interpolator(x): return output -def pitch_contour(times, frequencies, fs, amplitudes=None, function=np.sin, - length=None, kind='linear'): - '''Sonify a pitch contour. +def pitch_contour( + times, frequencies, fs, amplitudes=None, function=np.sin, length=None, kind="linear" +): + r"""Sonify a pitch contour. Parameters ---------- times : np.ndarray time indices for each frequency measurement, in seconds - frequencies : np.ndarray frequency measurements, in Hz. Non-positive measurements will be interpreted as un-voiced samples. - fs : int desired sampling rate of the output signal - amplitudes : np.ndarray - amplitude measurments, nonnegative + amplitude measurements, nonnegative defaults to ``np.ones((length,))`` - function : function function to use to synthesize notes, should be :math:`2\pi`-periodic - length : int desired number of samples in the output signal, defaults to ``max(times)*fs`` - kind : str Interpolation mode for the frequency and amplitude values. See: ``scipy.interpolate.interp1d`` for valid settings. @@ -233,8 +246,7 @@ def pitch_contour(times, frequencies, fs, amplitudes=None, function=np.sin, ------- output : np.ndarray synthesized version of the pitch contour - ''' - + """ fs = float(fs) if length is None: @@ -245,19 +257,30 @@ def pitch_contour(times, frequencies, fs, amplitudes=None, function=np.sin, frequencies = np.maximum(frequencies, 0.0) # Build a frequency interpolator - f_interp = interp1d(times * fs, 2 * np.pi * frequencies / fs, kind=kind, - fill_value=0.0, bounds_error=False, copy=False) + f_interp = interp1d( + times * fs, + 2 * np.pi * frequencies / fs, + kind=kind, + fill_value=0.0, + bounds_error=False, + copy=False, + ) # Estimate frequency at sample points f_est = f_interp(np.arange(length)) if amplitudes is None: - a_est = np.ones((length, )) + a_est = np.ones((length,)) else: # build an amplitude interpolator a_interp = interp1d( - times * fs, amplitudes, kind=kind, - fill_value=0.0, bounds_error=False, copy=False) + times * fs, + amplitudes, + kind=kind, + fill_value=0.0, + bounds_error=False, + copy=False, + ) a_est = a_interp(np.arange(length)) # Sonify the waveform @@ -273,12 +296,12 @@ def chroma(chromagram, times, fs, **kwargs): Chromagram matrix, where each row represents a semitone [C->Bb] i.e., ``chromagram[3, j]`` is the magnitude of D# from ``times[j]`` to ``times[j + 1]`` - times: np.ndarray, shape=(len(chord_labels),) or (len(chord_labels), 2) + times : np.ndarray, shape=(len(chord_labels),) or (len(chord_labels), 2) Either the start time of each column in the chromagram, or the time interval corresponding to each column. fs : int Sampling rate to synthesize audio data at - kwargs + **kwargs Additional keyword arguments to pass to :func:`mir_eval.sonify.time_frequency` @@ -298,8 +321,8 @@ def chroma(chromagram, times, fs, **kwargs): # and std 6 (one half octave) mean = 72 std = 6 - notes = np.arange(12*n_octaves) + base_note - shepard_weight = np.exp(-(notes - mean)**2./(2.*std**2.)) + notes = np.arange(12 * n_octaves) + base_note + shepard_weight = np.exp(-((notes - mean) ** 2.0) / (2.0 * std**2.0)) # Copy the chromagram matrix vertically n_octaves times gram = np.tile(chromagram.T, n_octaves).T # This fixes issues if the supplied chromagram is int type @@ -307,7 +330,7 @@ def chroma(chromagram, times, fs, **kwargs): # Apply Sheppard weighting gram *= shepard_weight.reshape(-1, 1) # Compute frequencies - frequencies = 440.0*(2.0**((notes - 69)/12.0)) + frequencies = 440.0 * (2.0 ** ((notes - 69) / 12.0)) return time_frequency(gram, frequencies, times, fs, **kwargs) @@ -322,7 +345,7 @@ def chords(chord_labels, intervals, fs, **kwargs): Start and end times of each chord label fs : int Sampling rate to synthesize at - kwargs + **kwargs Additional keyword arguments to pass to :func:`mir_eval.sonify.time_frequency` @@ -336,8 +359,11 @@ def chords(chord_labels, intervals, fs, **kwargs): # Convert from labels to chroma roots, interval_bitmaps, _ = chord.encode_many(chord_labels) - chromagram = np.array([np.roll(interval_bitmap, root) - for (interval_bitmap, root) - in zip(interval_bitmaps, roots)]).T + chromagram = np.array( + [ + np.roll(interval_bitmap, root) + for (interval_bitmap, root) in zip(interval_bitmaps, roots) + ] + ).T return chroma(chromagram, intervals, fs, **kwargs) diff --git a/mir_eval/tempo.py b/mir_eval/tempo.py index 935e81cf..aa090da8 100644 --- a/mir_eval/tempo.py +++ b/mir_eval/tempo.py @@ -1,4 +1,4 @@ -''' +""" The goal of a tempo estimation algorithm is to automatically detect the tempo of a piece of music, measured in beats per minute (BPM). @@ -18,7 +18,7 @@ * :func:`mir_eval.tempo.detection`: Relative error, hits, and weighted precision of tempo estimation. -''' +""" import warnings import numpy as np @@ -27,42 +27,38 @@ def validate_tempi(tempi, reference=True): - """Checks that there are two non-negative tempi. + """Check that there are two non-negative tempi. For a reference value, at least one tempo has to be greater than zero. Parameters ---------- tempi : np.ndarray length-2 array of tempo, in bpm - reference : bool indicates a reference value - """ - if tempi.size != 2: - raise ValueError('tempi must have exactly two values') + raise ValueError("tempi must have exactly two values") if not np.all(np.isfinite(tempi)) or np.any(tempi < 0): - raise ValueError('tempi={} must be non-negative numbers'.format(tempi)) + raise ValueError("tempi={} must be non-negative numbers".format(tempi)) if reference and np.all(tempi == 0): - raise ValueError('reference tempi={} must have one' - ' value greater than zero'.format(tempi)) + raise ValueError( + "reference tempi={} must have one" " value greater than zero".format(tempi) + ) def validate(reference_tempi, reference_weight, estimated_tempi): - """Checks that the input annotations to a metric look like valid tempo + """Check that the input annotations to a metric look like valid tempo annotations. Parameters ---------- reference_tempi : np.ndarray reference tempo values, in bpm - reference_weight : float perceptual weight of slow vs fast in reference - estimated_tempi : np.ndarray estimated tempo values, in bpm @@ -71,7 +67,7 @@ def validate(reference_tempi, reference_weight, estimated_tempi): validate_tempi(estimated_tempi, reference=False) if reference_weight < 0 or reference_weight > 1: - raise ValueError('Reference weight must lie in range [0, 1]') + raise ValueError("Reference weight must lie in range [0, 1]") def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): @@ -81,14 +77,11 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): ---------- reference_tempi : np.ndarray, shape=(2,) Two non-negative reference tempi - reference_weight : float > 0 The relative strength of ``reference_tempi[0]`` vs ``reference_tempi[1]``. - estimated_tempi : np.ndarray, shape=(2,) Two non-negative estimated tempi. - tol : float in [0, 1]: The maximum allowable deviation from a reference tempo to count as a hit. @@ -100,10 +93,8 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): p_score : float in [0, 1] Weighted average of recalls: ``reference_weight * hits[0] + (1 - reference_weight) * hits[1]`` - one_correct : bool True if at least one reference tempo was correctly estimated - both_correct : bool True if both reference tempi were correctly estimated @@ -116,15 +107,14 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): If ``tol < 0`` or ``tol > 1``. """ - validate(reference_tempi, reference_weight, estimated_tempi) if tol < 0 or tol > 1: - raise ValueError('invalid tolerance {}: must lie in the range ' - '[0, 1]'.format(tol)) - if tol == 0.: - warnings.warn('A tolerance of 0.0 may not ' - 'lead to the results you expect.') + raise ValueError( + "invalid tolerance {}: must lie in the range " "[0, 1]".format(tol) + ) + if tol == 0.0: + warnings.warn("A tolerance of 0.0 may not " "lead to the results you expect.") hits = [False, False] @@ -137,7 +127,7 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): # Count the hits hits[i] = relative_error <= tol - p_score = reference_weight * hits[0] + (1.0-reference_weight) * hits[1] + p_score = reference_weight * hits[0] + (1.0 - reference_weight) * hits[1] one_correct = bool(np.max(hits)) both_correct = bool(np.min(hits)) @@ -152,15 +142,12 @@ def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): ---------- reference_tempi : np.ndarray, shape=(2,) Two non-negative reference tempi - reference_weight : float > 0 The relative strength of ``reference_tempi[0]`` vs ``reference_tempi[1]``. - estimated_tempi : np.ndarray, shape=(2,) Two non-negative estimated tempi. - - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -173,11 +160,12 @@ def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): # Compute all metrics scores = collections.OrderedDict() - (scores['P-score'], - scores['One-correct'], - scores['Both-correct']) = util.filter_kwargs(detection, reference_tempi, - reference_weight, - estimated_tempi, - **kwargs) + ( + scores["P-score"], + scores["One-correct"], + scores["Both-correct"], + ) = util.filter_kwargs( + detection, reference_tempi, reference_weight, estimated_tempi, **kwargs + ) return scores diff --git a/mir_eval/transcription.py b/mir_eval/transcription.py index db93bd3b..65504279 100644 --- a/mir_eval/transcription.py +++ b/mir_eval/transcription.py @@ -1,4 +1,4 @@ -''' +""" The aim of a transcription algorithm is to produce a symbolic representation of a recorded piece of music in the form of a set of discrete notes. There are different ways to represent notes symbolically. Here we use the piano-roll @@ -102,7 +102,7 @@ account, meaning two notes could be matched even if they have very different pitch values. -''' +""" import numpy as np import collections @@ -115,7 +115,7 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches): - """Checks that the input annotations to a metric look like time intervals + """Check that the input annotations to a metric look like time intervals and a pitch list, and throws helpful errors if not. Parameters @@ -134,23 +134,19 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches): # Make sure intervals and pitches match in length if not ref_intervals.shape[0] == ref_pitches.shape[0]: - raise ValueError('Reference intervals and pitches have different ' - 'lengths.') + raise ValueError("Reference intervals and pitches have different " "lengths.") if not est_intervals.shape[0] == est_pitches.shape[0]: - raise ValueError('Estimated intervals and pitches have different ' - 'lengths.') + raise ValueError("Estimated intervals and pitches have different " "lengths.") # Make sure all pitch values are positive if ref_pitches.size > 0 and np.min(ref_pitches) <= 0: - raise ValueError("Reference contains at least one non-positive pitch " - "value") + raise ValueError("Reference contains at least one non-positive pitch " "value") if est_pitches.size > 0 and np.min(est_pitches) <= 0: - raise ValueError("Estimate contains at least one non-positive pitch " - "value") + raise ValueError("Estimate contains at least one non-positive pitch " "value") def validate_intervals(ref_intervals, est_intervals): - """Checks that the input annotations to a metric look like time intervals, + """Check that the input annotations to a metric look like time intervals, and throws helpful errors if not. Parameters @@ -171,8 +167,13 @@ def validate_intervals(ref_intervals, est_intervals): util.validate_intervals(est_intervals) -def match_note_offsets(ref_intervals, est_intervals, offset_ratio=0.2, - offset_min_tolerance=0.05, strict=False): +def match_note_offsets( + ref_intervals, + est_intervals, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, +): """Compute a maximum matching between reference and estimated notes, only taking note offsets into account. @@ -229,17 +230,16 @@ def match_note_offsets(ref_intervals, est_intervals, offset_ratio=0.2, cmp_func = np.less_equal # check for offset matches - offset_distances = np.abs(np.subtract.outer(ref_intervals[:, 1], - est_intervals[:, 1])) + offset_distances = np.abs( + np.subtract.outer(ref_intervals[:, 1], est_intervals[:, 1]) + ) # Round distances to a target precision to avoid the situation where # if the distance is exactly 50ms (and strict=False) it erroneously # doesn't match the notes because of precision issues. offset_distances = np.around(offset_distances, decimals=N_DECIMALS) ref_durations = util.intervals_to_durations(ref_intervals) - offset_tolerances = np.maximum(offset_ratio * ref_durations, - offset_min_tolerance) - offset_hit_matrix = ( - cmp_func(offset_distances, offset_tolerances.reshape(-1, 1))) + offset_tolerances = np.maximum(offset_ratio * ref_durations, offset_min_tolerance) + offset_hit_matrix = cmp_func(offset_distances, offset_tolerances.reshape(-1, 1)) # check for hits hits = np.where(offset_hit_matrix) @@ -260,8 +260,7 @@ def match_note_offsets(ref_intervals, est_intervals, offset_ratio=0.2, return matching -def match_note_onsets(ref_intervals, est_intervals, onset_tolerance=0.05, - strict=False): +def match_note_onsets(ref_intervals, est_intervals, onset_tolerance=0.05, strict=False): """Compute a maximum matching between reference and estimated notes, only taking note onsets into account. @@ -306,8 +305,9 @@ def match_note_onsets(ref_intervals, est_intervals, onset_tolerance=0.05, cmp_func = np.less_equal # check for onset matches - onset_distances = np.abs(np.subtract.outer(ref_intervals[:, 0], - est_intervals[:, 0])) + onset_distances = np.abs( + np.subtract.outer(ref_intervals[:, 0], est_intervals[:, 0]) + ) # Round distances to a target precision to avoid the situation where # if the distance is exactly 50ms (and strict=False) it erroneously # doesn't match the notes because of precision issues. @@ -333,9 +333,17 @@ def match_note_onsets(ref_intervals, est_intervals, onset_tolerance=0.05, return matching -def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, - onset_tolerance=0.05, pitch_tolerance=50.0, offset_ratio=0.2, - offset_min_tolerance=0.05, strict=False): +def match_notes( + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance=0.05, + pitch_tolerance=50.0, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, +): """Compute a maximum matching between reference and estimated notes, subject to onset, pitch and (optionally) offset constraints. @@ -414,8 +422,9 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, cmp_func = np.less_equal # check for onset matches - onset_distances = np.abs(np.subtract.outer(ref_intervals[:, 0], - est_intervals[:, 0])) + onset_distances = np.abs( + np.subtract.outer(ref_intervals[:, 0], est_intervals[:, 0]) + ) # Round distances to a target precision to avoid the situation where # if the distance is exactly 50ms (and strict=False) it erroneously # doesn't match the notes because of precision issues. @@ -423,23 +432,25 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, onset_hit_matrix = cmp_func(onset_distances, onset_tolerance) # check for pitch matches - pitch_distances = np.abs(1200*np.subtract.outer(np.log2(ref_pitches), - np.log2(est_pitches))) + pitch_distances = np.abs( + 1200 * np.subtract.outer(np.log2(ref_pitches), np.log2(est_pitches)) + ) pitch_hit_matrix = cmp_func(pitch_distances, pitch_tolerance) # check for offset matches if offset_ratio is not None if offset_ratio is not None: - offset_distances = np.abs(np.subtract.outer(ref_intervals[:, 1], - est_intervals[:, 1])) + offset_distances = np.abs( + np.subtract.outer(ref_intervals[:, 1], est_intervals[:, 1]) + ) # Round distances to a target precision to avoid the situation where # if the distance is exactly 50ms (and strict=False) it erroneously # doesn't match the notes because of precision issues. offset_distances = np.around(offset_distances, decimals=N_DECIMALS) ref_durations = util.intervals_to_durations(ref_intervals) - offset_tolerances = np.maximum(offset_ratio * ref_durations, - offset_min_tolerance) - offset_hit_matrix = ( - cmp_func(offset_distances, offset_tolerances.reshape(-1, 1))) + offset_tolerances = np.maximum( + offset_ratio * ref_durations, offset_min_tolerance + ) + offset_hit_matrix = cmp_func(offset_distances, offset_tolerances.reshape(-1, 1)) else: offset_hit_matrix = True @@ -463,11 +474,18 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, return matching -def precision_recall_f1_overlap(ref_intervals, ref_pitches, est_intervals, - est_pitches, onset_tolerance=0.05, - pitch_tolerance=50.0, offset_ratio=0.2, - offset_min_tolerance=0.05, strict=False, - beta=1.0): +def precision_recall_f1_overlap( + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance=0.05, + pitch_tolerance=50.0, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, + beta=1.0, +): """Compute the Precision, Recall and F-measure of correct vs incorrectly transcribed notes, and the Average Overlap Ratio for correctly transcribed notes (see :func:`average_overlap_ratio`). "Correctness" is determined @@ -548,21 +566,25 @@ def precision_recall_f1_overlap(ref_intervals, ref_pitches, est_intervals, validate(ref_intervals, ref_pitches, est_intervals, est_pitches) # When reference notes are empty, metrics are undefined, return 0's if len(ref_pitches) == 0 or len(est_pitches) == 0: - return 0., 0., 0., 0. - - matching = match_notes(ref_intervals, ref_pitches, est_intervals, - est_pitches, onset_tolerance=onset_tolerance, - pitch_tolerance=pitch_tolerance, - offset_ratio=offset_ratio, - offset_min_tolerance=offset_min_tolerance, - strict=strict) - - precision = float(len(matching))/len(est_pitches) - recall = float(len(matching))/len(ref_pitches) + return 0.0, 0.0, 0.0, 0.0 + + matching = match_notes( + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance=onset_tolerance, + pitch_tolerance=pitch_tolerance, + offset_ratio=offset_ratio, + offset_min_tolerance=offset_min_tolerance, + strict=strict, + ) + + precision = float(len(matching)) / len(est_pitches) + recall = float(len(matching)) / len(ref_pitches) f_measure = util.f_measure(precision, recall, beta=beta) - avg_overlap_ratio = average_overlap_ratio(ref_intervals, est_intervals, - matching) + avg_overlap_ratio = average_overlap_ratio(ref_intervals, est_intervals, matching) return precision, recall, f_measure, avg_overlap_ratio @@ -608,9 +630,9 @@ def average_overlap_ratio(ref_intervals, est_intervals, matching): for match in matching: ref_int = ref_intervals[match[0]] est_int = est_intervals[match[1]] - overlap_ratio = ( - (min(ref_int[1], est_int[1]) - max(ref_int[0], est_int[0])) / - (max(ref_int[1], est_int[1]) - min(ref_int[0], est_int[0]))) + overlap_ratio = (min(ref_int[1], est_int[1]) - max(ref_int[0], est_int[0])) / ( + max(ref_int[1], est_int[1]) - min(ref_int[0], est_int[0]) + ) ratios.append(overlap_ratio) if len(ratios) == 0: @@ -619,8 +641,9 @@ def average_overlap_ratio(ref_intervals, est_intervals, matching): return np.mean(ratios) -def onset_precision_recall_f1(ref_intervals, est_intervals, - onset_tolerance=0.05, strict=False, beta=1.0): +def onset_precision_recall_f1( + ref_intervals, est_intervals, onset_tolerance=0.05, strict=False, beta=1.0 +): """Compute the Precision, Recall and F-measure of note onsets: an estimated onset is considered correct if it is within +-50ms of a reference onset. Note that this metric completely ignores note offset and note pitch. This @@ -629,7 +652,6 @@ def onset_precision_recall_f1(ref_intervals, est_intervals, different pitches (i.e. notes that would not match with :func:`match_notes`). - Examples -------- >>> ref_intervals, _ = mir_eval.io.load_valued_intervals( @@ -669,21 +691,26 @@ def onset_precision_recall_f1(ref_intervals, est_intervals, validate_intervals(ref_intervals, est_intervals) # When reference notes are empty, metrics are undefined, return 0's if len(ref_intervals) == 0 or len(est_intervals) == 0: - return 0., 0., 0. + return 0.0, 0.0, 0.0 - matching = match_note_onsets(ref_intervals, est_intervals, - onset_tolerance=onset_tolerance, - strict=strict) + matching = match_note_onsets( + ref_intervals, est_intervals, onset_tolerance=onset_tolerance, strict=strict + ) - onset_precision = float(len(matching))/len(est_intervals) - onset_recall = float(len(matching))/len(ref_intervals) + onset_precision = float(len(matching)) / len(est_intervals) + onset_recall = float(len(matching)) / len(ref_intervals) onset_f_measure = util.f_measure(onset_precision, onset_recall, beta=beta) return onset_precision, onset_recall, onset_f_measure -def offset_precision_recall_f1(ref_intervals, est_intervals, offset_ratio=0.2, - offset_min_tolerance=0.05, strict=False, - beta=1.0): +def offset_precision_recall_f1( + ref_intervals, + est_intervals, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, + beta=1.0, +): """Compute the Precision, Recall and F-measure of note offsets: an estimated offset is considered correct if it is within +-50ms (or 20% of the ref note duration, which ever is greater) of a reference offset. Note @@ -693,7 +720,6 @@ def offset_precision_recall_f1(ref_intervals, est_intervals, offset_ratio=0.2, different pitches (i.e. notes that would not match with :func:`match_notes`). - Examples -------- >>> ref_intervals, _ = mir_eval.io.load_valued_intervals( @@ -740,17 +766,19 @@ def offset_precision_recall_f1(ref_intervals, est_intervals, offset_ratio=0.2, validate_intervals(ref_intervals, est_intervals) # When reference notes are empty, metrics are undefined, return 0's if len(ref_intervals) == 0 or len(est_intervals) == 0: - return 0., 0., 0. - - matching = match_note_offsets(ref_intervals, est_intervals, - offset_ratio=offset_ratio, - offset_min_tolerance=offset_min_tolerance, - strict=strict) - - offset_precision = float(len(matching))/len(est_intervals) - offset_recall = float(len(matching))/len(ref_intervals) - offset_f_measure = util.f_measure(offset_precision, offset_recall, - beta=beta) + return 0.0, 0.0, 0.0 + + matching = match_note_offsets( + ref_intervals, + est_intervals, + offset_ratio=offset_ratio, + offset_min_tolerance=offset_min_tolerance, + strict=strict, + ) + + offset_precision = float(len(matching)) / len(est_intervals) + offset_recall = float(len(matching)) / len(ref_intervals) + offset_f_measure = util.f_measure(offset_precision, offset_recall, beta=beta) return offset_precision, offset_recall, offset_f_measure @@ -776,7 +804,7 @@ def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs): Array of estimated notes time intervals (onset and offset times) est_pitches : np.ndarray, shape=(m,) Array of estimated pitch values in Hertz - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -790,40 +818,57 @@ def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs): scores = collections.OrderedDict() # Precision, recall and f-measure taking note offsets into account - kwargs.setdefault('offset_ratio', 0.2) - orig_offset_ratio = kwargs['offset_ratio'] - if kwargs['offset_ratio'] is not None: - (scores['Precision'], - scores['Recall'], - scores['F-measure'], - scores['Average_Overlap_Ratio']) = util.filter_kwargs( - precision_recall_f1_overlap, ref_intervals, ref_pitches, - est_intervals, est_pitches, **kwargs) + kwargs.setdefault("offset_ratio", 0.2) + orig_offset_ratio = kwargs["offset_ratio"] + if kwargs["offset_ratio"] is not None: + ( + scores["Precision"], + scores["Recall"], + scores["F-measure"], + scores["Average_Overlap_Ratio"], + ) = util.filter_kwargs( + precision_recall_f1_overlap, + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + **kwargs + ) # Precision, recall and f-measure NOT taking note offsets into account - kwargs['offset_ratio'] = None - (scores['Precision_no_offset'], - scores['Recall_no_offset'], - scores['F-measure_no_offset'], - scores['Average_Overlap_Ratio_no_offset']) = ( - util.filter_kwargs(precision_recall_f1_overlap, - ref_intervals, ref_pitches, - est_intervals, est_pitches, **kwargs)) + kwargs["offset_ratio"] = None + ( + scores["Precision_no_offset"], + scores["Recall_no_offset"], + scores["F-measure_no_offset"], + scores["Average_Overlap_Ratio_no_offset"], + ) = util.filter_kwargs( + precision_recall_f1_overlap, + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + **kwargs + ) # onset-only metrics - (scores['Onset_Precision'], - scores['Onset_Recall'], - scores['Onset_F-measure']) = ( - util.filter_kwargs(onset_precision_recall_f1, - ref_intervals, est_intervals, **kwargs)) + ( + scores["Onset_Precision"], + scores["Onset_Recall"], + scores["Onset_F-measure"], + ) = util.filter_kwargs( + onset_precision_recall_f1, ref_intervals, est_intervals, **kwargs + ) # offset-only metrics - kwargs['offset_ratio'] = orig_offset_ratio - if kwargs['offset_ratio'] is not None: - (scores['Offset_Precision'], - scores['Offset_Recall'], - scores['Offset_F-measure']) = ( - util.filter_kwargs(offset_precision_recall_f1, - ref_intervals, est_intervals, **kwargs)) + kwargs["offset_ratio"] = orig_offset_ratio + if kwargs["offset_ratio"] is not None: + ( + scores["Offset_Precision"], + scores["Offset_Recall"], + scores["Offset_F-measure"], + ) = util.filter_kwargs( + offset_precision_recall_f1, ref_intervals, est_intervals, **kwargs + ) return scores diff --git a/mir_eval/transcription_velocity.py b/mir_eval/transcription_velocity.py index c7aac282..866ac97e 100644 --- a/mir_eval/transcription_velocity.py +++ b/mir_eval/transcription_velocity.py @@ -59,9 +59,15 @@ from . import util -def validate(ref_intervals, ref_pitches, ref_velocities, est_intervals, - est_pitches, est_velocities): - """Checks that the input annotations have valid time intervals, pitches, +def validate( + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, +): + """Check that the input annotations have valid time intervals, pitches, and velocities, and throws helpful errors if not. Parameters @@ -79,27 +85,39 @@ def validate(ref_intervals, ref_pitches, ref_velocities, est_intervals, est_velocities : np.ndarray, shape=(m,) Array of MIDI velocities (i.e. between 0 and 127) of estimated notes """ - transcription.validate(ref_intervals, ref_pitches, est_intervals, - est_pitches) + transcription.validate(ref_intervals, ref_pitches, est_intervals, est_pitches) # Check that velocities have the same length as intervals/pitches if not ref_velocities.shape[0] == ref_pitches.shape[0]: - raise ValueError('Reference velocities must have the same length as ' - 'pitches and intervals.') + raise ValueError( + "Reference velocities must have the same length as " + "pitches and intervals." + ) if not est_velocities.shape[0] == est_pitches.shape[0]: - raise ValueError('Estimated velocities must have the same length as ' - 'pitches and intervals.') + raise ValueError( + "Estimated velocities must have the same length as " + "pitches and intervals." + ) # Check that the velocities are positive if ref_velocities.size > 0 and np.min(ref_velocities) < 0: - raise ValueError('Reference velocities must be positive.') + raise ValueError("Reference velocities must be positive.") if est_velocities.size > 0 and np.min(est_velocities) < 0: - raise ValueError('Estimated velocities must be positive.') + raise ValueError("Estimated velocities must be positive.") def match_notes( - ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches, - est_velocities, onset_tolerance=0.05, pitch_tolerance=50.0, - offset_ratio=0.2, offset_min_tolerance=0.05, strict=False, - velocity_tolerance=0.1): + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + onset_tolerance=0.05, + pitch_tolerance=50.0, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, + velocity_tolerance=0.1, +): """Match notes, taking note velocity into consideration. This function first calls :func:`mir_eval.transcription.match_notes` to @@ -162,15 +180,22 @@ def match_notes( """ # Compute note matching as usual using standard transcription function matching = transcription.match_notes( - ref_intervals, ref_pitches, est_intervals, est_pitches, - onset_tolerance, pitch_tolerance, offset_ratio, offset_min_tolerance, - strict) + ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance, + pitch_tolerance, + offset_ratio, + offset_min_tolerance, + strict, + ) # Rescale reference velocities to the range [0, 1] min_velocity, max_velocity = np.min(ref_velocities), np.max(ref_velocities) # Make the smallest possible range 1 to avoid divide by zero velocity_range = max(1, max_velocity - min_velocity) - ref_velocities = (ref_velocities - min_velocity)/float(velocity_range) + ref_velocities = (ref_velocities - min_velocity) / float(velocity_range) # Convert matching list-of-tuples to array for fancy indexing matching = np.array(matching) @@ -183,16 +208,17 @@ def match_notes( # Find slope and intercept of line which produces best least-squares fit # between matched est and ref velocities slope, intercept = np.linalg.lstsq( - np.vstack([est_matched_velocities, - np.ones(len(est_matched_velocities))]).T, - ref_matched_velocities)[0] + np.vstack([est_matched_velocities, np.ones(len(est_matched_velocities))]).T, + ref_matched_velocities, + rcond=None, + )[0] # Re-scale est velocities to match ref - est_matched_velocities = slope*est_matched_velocities + intercept + est_matched_velocities = slope * est_matched_velocities + intercept # Compute the absolute error of (rescaled) estimated velocities vs. # normalized reference velocities. Error will be in [0, 1] velocity_diff = np.abs(est_matched_velocities - ref_matched_velocities) # Check whether each error is within the provided tolerance - velocity_within_tolerance = (velocity_diff < velocity_tolerance) + velocity_within_tolerance = velocity_diff < velocity_tolerance # Only keep matches whose velocity was within the provided tolerance matching = matching[velocity_within_tolerance] # Convert back to list-of-tuple format @@ -202,10 +228,20 @@ def match_notes( def precision_recall_f1_overlap( - ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches, - est_velocities, onset_tolerance=0.05, pitch_tolerance=50.0, - offset_ratio=0.2, offset_min_tolerance=0.05, strict=False, - velocity_tolerance=0.1, beta=1.0): + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + onset_tolerance=0.05, + pitch_tolerance=50.0, + offset_ratio=0.2, + offset_min_tolerance=0.05, + strict=False, + velocity_tolerance=0.1, + beta=1.0, +): """Compute the Precision, Recall and F-measure of correct vs incorrectly transcribed notes, and the Average Overlap Ratio for correctly transcribed notes (see :func:`mir_eval.transcription.average_overlap_ratio`). @@ -282,29 +318,53 @@ def precision_recall_f1_overlap( avg_overlap_ratio : float The computed Average Overlap Ratio score """ - validate(ref_intervals, ref_pitches, ref_velocities, est_intervals, - est_pitches, est_velocities) + validate( + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + ) # When reference notes are empty, metrics are undefined, return 0's if len(ref_pitches) == 0 or len(est_pitches) == 0: - return 0., 0., 0., 0. + return 0.0, 0.0, 0.0, 0.0 matching = match_notes( - ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches, - est_velocities, onset_tolerance, pitch_tolerance, offset_ratio, - offset_min_tolerance, strict, velocity_tolerance) - - precision = float(len(matching))/len(est_pitches) - recall = float(len(matching))/len(ref_pitches) + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + onset_tolerance, + pitch_tolerance, + offset_ratio, + offset_min_tolerance, + strict, + velocity_tolerance, + ) + + precision = float(len(matching)) / len(est_pitches) + recall = float(len(matching)) / len(ref_pitches) f_measure = util.f_measure(precision, recall, beta=beta) avg_overlap_ratio = transcription.average_overlap_ratio( - ref_intervals, est_intervals, matching) + ref_intervals, est_intervals, matching + ) return precision, recall, f_measure, avg_overlap_ratio -def evaluate(ref_intervals, ref_pitches, ref_velocities, est_intervals, - est_pitches, est_velocities, **kwargs): +def evaluate( + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + **kwargs +): """Compute all metrics for the given reference and estimated annotations. Parameters @@ -321,7 +381,7 @@ def evaluate(ref_intervals, ref_pitches, ref_velocities, est_intervals, Array of estimated pitch values in Hertz est_velocities : np.ndarray, shape=(n,) Array of MIDI velocities (i.e. between 0 and 127) of estimated notes - kwargs + **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -335,23 +395,40 @@ def evaluate(ref_intervals, ref_pitches, ref_velocities, est_intervals, scores = collections.OrderedDict() # Precision, recall and f-measure taking note offsets into account - kwargs.setdefault('offset_ratio', 0.2) - if kwargs['offset_ratio'] is not None: - (scores['Precision'], - scores['Recall'], - scores['F-measure'], - scores['Average_Overlap_Ratio']) = util.filter_kwargs( - precision_recall_f1_overlap, ref_intervals, ref_pitches, - ref_velocities, est_intervals, est_pitches, est_velocities, - **kwargs) + kwargs.setdefault("offset_ratio", 0.2) + if kwargs["offset_ratio"] is not None: + ( + scores["Precision"], + scores["Recall"], + scores["F-measure"], + scores["Average_Overlap_Ratio"], + ) = util.filter_kwargs( + precision_recall_f1_overlap, + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + **kwargs + ) # Precision, recall and f-measure NOT taking note offsets into account - kwargs['offset_ratio'] = None - (scores['Precision_no_offset'], - scores['Recall_no_offset'], - scores['F-measure_no_offset'], - scores['Average_Overlap_Ratio_no_offset']) = util.filter_kwargs( - precision_recall_f1_overlap, ref_intervals, ref_pitches, - ref_velocities, est_intervals, est_pitches, est_velocities, **kwargs) + kwargs["offset_ratio"] = None + ( + scores["Precision_no_offset"], + scores["Recall_no_offset"], + scores["F-measure_no_offset"], + scores["Average_Overlap_Ratio_no_offset"], + ) = util.filter_kwargs( + precision_recall_f1_overlap, + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + **kwargs + ) return scores diff --git a/mir_eval/util.py b/mir_eval/util.py index 6d5df2c1..4c9e2326 100644 --- a/mir_eval/util.py +++ b/mir_eval/util.py @@ -1,7 +1,7 @@ -''' -This submodule collects useful functionality required across the task -submodules, such as preprocessing, validation, and common computations. -''' +""" +Useful functionality required across the task submodules, +such as preprocessing, validation, and common computations. +""" import os import inspect @@ -17,7 +17,6 @@ def index_labels(labels, case_sensitive=False): labels : list of strings, shape=(n,) A list of annotations, e.g., segment or chord labels from an annotation file. - case_sensitive : bool Set to True to enable case-sensitive label indexing (Default value = False) @@ -29,9 +28,7 @@ def index_labels(labels, case_sensitive=False): index_to_label : dict Mapping to convert numerical indices back to labels. ``labels[i] == index_to_label[indices[i]]`` - """ - label_to_index = {} index_to_label = {} @@ -51,7 +48,7 @@ def index_labels(labels, case_sensitive=False): return indices, index_to_label -def generate_labels(items, prefix='__'): +def generate_labels(items, prefix="__"): """Given an array of items (e.g. events, intervals), create a synthetic label for each event of the form '(label prefix)(item number)' @@ -69,11 +66,10 @@ def generate_labels(items, prefix='__'): Synthetically generated labels """ - return ['{}{}'.format(prefix, n) for n in range(len(items))] + return ["{}{}".format(prefix, n) for n in range(len(items))] -def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, - fill_value=None): +def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, fill_value=None): """Convert an array of labeled time intervals to annotated samples. Parameters @@ -84,18 +80,14 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, :func:`mir_eval.io.load_labeled_intervals()`. The ``i`` th interval spans time ``intervals[i, 0]`` to ``intervals[i, 1]``. - labels : list, shape=(n,) The annotation for each interval - offset : float > 0 Phase offset of the sampled time grid (in seconds) (Default value = 0) - sample_size : float > 0 duration of each sample to be generated (in seconds) (Default value = 0.1) - fill_value : type(labels[0]) Object to use for the label with out-of-range time points. (Default value = None) @@ -104,7 +96,6 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, ------- sample_times : list list of sample times - sample_labels : list array of labels for each generated sample @@ -112,15 +103,12 @@ def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, ----- Intervals will be rounded down to the nearest multiple of ``sample_size``. - """ - # Round intervals to the sample size num_samples = int(np.floor(intervals.max() / sample_size)) sample_indices = np.arange(num_samples, dtype=np.float32) - sample_times = (sample_indices*sample_size + offset).tolist() - sampled_labels = interpolate_intervals( - intervals, labels, sample_times, fill_value) + sample_times = (sample_indices * sample_size + offset).tolist() + sampled_labels = interpolate_intervals(intervals, labels, sample_times, fill_value) return sample_times, sampled_labels @@ -161,33 +149,31 @@ def interpolate_intervals(intervals, labels, time_points, fill_value=None): ValueError If `time_points` is not in non-decreasing order. """ - # Verify that time_points is sorted time_points = np.asarray(time_points) if np.any(time_points[1:] < time_points[:-1]): - raise ValueError('time_points must be in non-decreasing order') + raise ValueError("time_points must be in non-decreasing order") aligned_labels = [fill_value] * len(time_points) - starts = np.searchsorted(time_points, intervals[:, 0], side='left') - ends = np.searchsorted(time_points, intervals[:, 1], side='right') + starts = np.searchsorted(time_points, intervals[:, 0], side="left") + ends = np.searchsorted(time_points, intervals[:, 1], side="right") - for (start, end, lab) in zip(starts, ends, labels): + for start, end, lab in zip(starts, ends, labels): aligned_labels[start:end] = [lab] * (end - start) return aligned_labels def sort_labeled_intervals(intervals, labels=None): - '''Sort intervals, and optionally, their corresponding labels + """Sort intervals, and optionally, their corresponding labels according to start time. Parameters ---------- intervals : np.ndarray, shape=(n, 2) The input intervals - labels : list, optional Labels for each interval @@ -195,8 +181,7 @@ def sort_labeled_intervals(intervals, labels=None): ------- intervals_sorted or (intervals_sorted, labels_sorted) Labels are only returned if provided as input - ''' - + """ idx = np.argsort(intervals[:, 0]) intervals_sorted = intervals[idx] @@ -224,13 +209,11 @@ def f_measure(precision, recall, beta=1.0): ------- f_measure : float The weighted f-measure - """ - if precision == 0 and recall == 0: return 0.0 - return (1 + beta**2)*precision*recall/((beta**2)*precision + recall) + return (1 + beta**2) * precision * recall / ((beta**2) * precision + recall) def intervals_to_boundaries(intervals, q=5): @@ -247,9 +230,7 @@ def intervals_to_boundaries(intervals, q=5): ------- boundaries : np.ndarray Interval boundary times, including the end of the final interval - """ - return np.unique(np.ravel(np.round(intervals, decimals=q))) @@ -267,21 +248,22 @@ def boundaries_to_intervals(boundaries): intervals : np.ndarray, shape=(n_intervals, 2) Start and end time for each interval """ - if not np.allclose(boundaries, np.unique(boundaries)): - raise ValueError('Boundary times are not unique or not ascending.') + raise ValueError("Boundary times are not unique or not ascending.") intervals = np.asarray(list(zip(boundaries[:-1], boundaries[1:]))) return intervals -def adjust_intervals(intervals, - labels=None, - t_min=0.0, - t_max=None, - start_label='__T_MIN', - end_label='__T_MAX'): +def adjust_intervals( + intervals, + labels=None, + t_min=0.0, + t_max=None, + start_label="__T_MIN", + end_label="__T_MAX", +): """Adjust a list of time intervals to span the range ``[t_min, t_max]``. Any intervals lying completely outside the specified range will be removed. @@ -320,9 +302,7 @@ def adjust_intervals(intervals, Intervals spanning ``[t_min, t_max]`` new_labels : list List of labels for ``new_labels`` - """ - # When supplied intervals are empty and t_max and t_min are supplied, # create one interval from t_min to t_max with the label start_label if t_min is not None and t_max is not None and intervals.size == 0: @@ -330,8 +310,7 @@ def adjust_intervals(intervals, # When intervals are empty and either t_min or t_max are not supplied, # we can't append new intervals elif (t_min is None or t_max is None) and intervals.size == 0: - raise ValueError("Supplied intervals are empty, can't append new" - " intervals") + raise ValueError("Supplied intervals are empty, can't append new" " intervals") if t_min is not None: # Find the intervals that end at or after t_min @@ -340,9 +319,9 @@ def adjust_intervals(intervals, if len(first_idx) > 0: # If we have events below t_min, crop them out if labels is not None: - labels = labels[int(first_idx[0]):] + labels = labels[first_idx[0, 0] :] # Clip to the range (t_min, +inf) - intervals = intervals[int(first_idx[0]):] + intervals = intervals[first_idx[0, 0] :] intervals = np.maximum(t_min, intervals) if intervals.min() > t_min: @@ -360,9 +339,9 @@ def adjust_intervals(intervals, # We have boundaries above t_max. # Trim to only boundaries <= t_max if labels is not None: - labels = labels[:int(last_idx[0])] + labels = labels[: last_idx[0, 0]] # Clip to the range (-inf, t_max) - intervals = intervals[:int(last_idx[0])] + intervals = intervals[: last_idx[0, 0]] intervals = np.minimum(t_max, intervals) @@ -375,8 +354,7 @@ def adjust_intervals(intervals, return intervals, labels -def adjust_events(events, labels=None, t_min=0.0, - t_max=None, label_prefix='__'): +def adjust_events(events, labels=None, t_min=0.0, t_max=None, label_prefix="__"): """Adjust the given list of event times to span the range ``[t_min, t_max]``. @@ -415,15 +393,15 @@ def adjust_events(events, labels=None, t_min=0.0, # We have events below t_min # Crop them out if labels is not None: - labels = labels[int(first_idx[0]):] - events = events[int(first_idx[0]):] + labels = labels[first_idx[0, 0] :] + events = events[first_idx[0, 0] :] if events[0] > t_min: # Lowest boundary is higher than t_min: # add a new boundary and label events = np.concatenate(([t_min], events)) if labels is not None: - labels.insert(0, '%sT_MIN' % label_prefix) + labels.insert(0, "%sT_MIN" % label_prefix) if t_max is not None: last_idx = np.argwhere(events > t_max) @@ -432,14 +410,14 @@ def adjust_events(events, labels=None, t_min=0.0, # We have boundaries above t_max. # Trim to only boundaries <= t_max if labels is not None: - labels = labels[:int(last_idx[0])] - events = events[:int(last_idx[0])] + labels = labels[: last_idx[0, 0]] + events = events[: last_idx[0, 0]] if events[-1] < t_max: # Last boundary is below t_max: add a new boundary and label events = np.concatenate((events, [t_max])) if labels is not None: - labels.append('%sT_MAX' % label_prefix) + labels.append("%sT_MAX" % label_prefix) return events, labels @@ -473,16 +451,17 @@ def intersect_files(flist1, flist2): corresponding filepaths from ``flist2`` """ + def fname(abs_path): - """Returns the filename given an absolute path. + """Return the filename given an absolute path. Parameters ---------- - abs_path : - + abs_path Returns ------- + filename """ return os.path.splitext(os.path.split(abs_path)[-1])[0] @@ -521,16 +500,17 @@ def merge_labeled_intervals(x_intervals, x_labels, y_intervals, y_labels): New labels for the sequence ``y`` """ - align_check = [x_intervals[0, 0] == y_intervals[0, 0], - x_intervals[-1, 1] == y_intervals[-1, 1]] + align_check = [ + x_intervals[0, 0] == y_intervals[0, 0], + x_intervals[-1, 1] == y_intervals[-1, 1], + ] if False in align_check: raise ValueError( "Time intervals do not align; did you mean to call " - "'adjust_intervals()' first?") - time_boundaries = np.unique( - np.concatenate([x_intervals, y_intervals], axis=0)) - output_intervals = np.array( - [time_boundaries[:-1], time_boundaries[1:]]).T + "'adjust_intervals()' first?" + ) + time_boundaries = np.unique(np.concatenate([x_intervals, y_intervals], axis=0)) + output_intervals = np.array([time_boundaries[:-1], time_boundaries[1:]]).T x_labels_out, y_labels_out = [], [] x_label_range = np.arange(len(x_labels)) @@ -710,7 +690,7 @@ def match_events(ref, est, window, distance=None): def _fast_hit_windows(ref, est, window): - '''Fast calculation of windowed hits for time events. + """Fast calculation of windowed hits for time events. Given two lists of event times ``ref`` and ``est``, and a tolerance window, computes a list of pairings @@ -735,15 +715,14 @@ def _fast_hit_windows(ref, est, window): hit_ref : np.ndarray hit_est : np.ndarray indices such that ``|hit_ref[i] - hit_est[i]| <= window`` - ''' - + """ ref = np.asarray(ref) est = np.asarray(est) ref_idx = np.argsort(ref) ref_sorted = ref[ref_idx] - left_idx = np.searchsorted(ref_sorted, est - window, side='left') - right_idx = np.searchsorted(ref_sorted, est + window, side='right') + left_idx = np.searchsorted(ref_sorted, est - window, side="left") + right_idx = np.searchsorted(ref_sorted, est + window, side="right") hit_ref, hit_est = [], [] @@ -755,32 +734,32 @@ def _fast_hit_windows(ref, est, window): def validate_intervals(intervals): - """Checks that an (n, 2) interval ndarray is well-formed, and raises errors + """Check that an (n, 2) interval ndarray is well-formed, and raises errors if not. Parameters ---------- intervals : np.ndarray, shape=(n, 2) Array of interval start/end locations. - """ - # Validate interval shape if intervals.ndim != 2 or intervals.shape[1] != 2: - raise ValueError('Intervals should be n-by-2 numpy ndarray, ' - 'but shape={}'.format(intervals.shape)) + raise ValueError( + "Intervals should be n-by-2 numpy ndarray, " + "but shape={}".format(intervals.shape) + ) # Make sure no times are negative if (intervals < 0).any(): - raise ValueError('Negative interval times found') + raise ValueError("Negative interval times found") # Make sure all intervals have strictly positive duration if (intervals[:, 1] <= intervals[:, 0]).any(): - raise ValueError('All interval durations must be strictly positive') + raise ValueError("All interval durations must be strictly positive") -def validate_events(events, max_time=30000.): - """Checks that a 1-d event location ndarray is well-formed, and raises +def validate_events(events, max_time=30000.0): + """Check that a 1-d event location ndarray is well-formed, and raises errors if not. Parameters @@ -790,26 +769,28 @@ def validate_events(events, max_time=30000.): max_time : float If an event is found above this time, a ValueError will be raised. (Default value = 30000.) - """ # Make sure no event times are huge if (events > max_time).any(): - raise ValueError('An event at time {} was found which is greater than ' - 'the maximum allowable time of max_time = {} (did you' - ' supply event times in ' - 'seconds?)'.format(events.max(), max_time)) + raise ValueError( + "An event at time {} was found which is greater than " + "the maximum allowable time of max_time = {} (did you" + " supply event times in " + "seconds?)".format(events.max(), max_time) + ) # Make sure event locations are 1-d np ndarrays if events.ndim != 1: - raise ValueError('Event times should be 1-d numpy ndarray, ' - 'but shape={}'.format(events.shape)) + raise ValueError( + "Event times should be 1-d numpy ndarray, " + "but shape={}".format(events.shape) + ) # Make sure event times are increasing if (np.diff(events) < 0).any(): - raise ValueError('Events should be in increasing order.') + raise ValueError("Events should be in increasing order.") -def validate_frequencies(frequencies, max_freq, min_freq, - allow_negatives=False): - """Checks that a 1-d frequency ndarray is well-formed, and raises +def validate_frequencies(frequencies, max_freq, min_freq, allow_negatives=False): + """Check that a 1-d frequency ndarray is well-formed, and raises errors if not. Parameters @@ -830,24 +811,30 @@ def validate_frequencies(frequencies, max_freq, min_freq, frequencies = np.abs(frequencies) # Make sure no frequency values are huge if (np.abs(frequencies) > max_freq).any(): - raise ValueError('A frequency of {} was found which is greater than ' - 'the maximum allowable value of max_freq = {} (did ' - 'you supply frequency values in ' - 'Hz?)'.format(frequencies.max(), max_freq)) + raise ValueError( + "A frequency of {} was found which is greater than " + "the maximum allowable value of max_freq = {} (did " + "you supply frequency values in " + "Hz?)".format(frequencies.max(), max_freq) + ) # Make sure no frequency values are tiny if (np.abs(frequencies) < min_freq).any(): - raise ValueError('A frequency of {} was found which is less than the ' - 'minimum allowable value of min_freq = {} (did you ' - 'supply frequency values in ' - 'Hz?)'.format(frequencies.min(), min_freq)) + raise ValueError( + "A frequency of {} was found which is less than the " + "minimum allowable value of min_freq = {} (did you " + "supply frequency values in " + "Hz?)".format(frequencies.min(), min_freq) + ) # Make sure frequency values are 1-d np ndarrays if frequencies.ndim != 1: - raise ValueError('Frequencies should be 1-d numpy ndarray, ' - 'but shape={}'.format(frequencies.shape)) + raise ValueError( + "Frequencies should be 1-d numpy ndarray, " + "but shape={}".format(frequencies.shape) + ) def has_kwargs(function): - r'''Determine whether a function has \*\*kwargs. + r"""Determine whether a function has \*\*kwargs. Parameters ---------- @@ -858,8 +845,7 @@ def has_kwargs(function): ------- True if function accepts arbitrary keyword arguments. False otherwise. - ''' - + """ sig = inspect.signature(function) for param in list(sig.parameters.values()): @@ -870,7 +856,7 @@ def has_kwargs(function): def filter_kwargs(_function, *args, **kwargs): - """Given a function and args and keyword args to pass to it, call the function + r"""Given a function and args and keyword args to pass to it, call the function but using only the keyword arguments which it accepts. This is equivalent to redefining the function with an additional \*\*kwargs to accept slop keyword args. @@ -882,15 +868,16 @@ def filter_kwargs(_function, *args, **kwargs): ---------- _function : callable Function to call. Can take in any number of args or kwargs - + *args + **kwargs + Arguments and keyword arguments to _function. """ - if has_kwargs(_function): return _function(*args, **kwargs) # Get the list of function arguments func_code = _function.__code__ - function_args = func_code.co_varnames[:func_code.co_argcount] + function_args = func_code.co_varnames[: func_code.co_argcount] # Construct a dict of those kwargs which appear in the function filtered_kwargs = {} for kwarg, value in list(kwargs.items()): @@ -901,7 +888,7 @@ def filter_kwargs(_function, *args, **kwargs): def intervals_to_durations(intervals): - """Converts an array of n intervals to their n durations. + """Convert an array of n intervals to their n durations. Parameters ---------- @@ -922,7 +909,7 @@ def intervals_to_durations(intervals): def hz_to_midi(freqs): - '''Convert Hz to MIDI numbers + """Convert Hz to MIDI numbers Parameters ---------- @@ -934,12 +921,12 @@ def hz_to_midi(freqs): midi : number or ndarray MIDI note numbers corresponding to input frequencies. Note that these may be fractional. - ''' + """ return 12.0 * (np.log2(freqs) - np.log2(440.0)) + 69.0 def midi_to_hz(midi): - '''Convert MIDI numbers to Hz + """Convert MIDI numbers to Hz Parameters ---------- @@ -950,5 +937,5 @@ def midi_to_hz(midi): ------- freqs : number or ndarray Frequency/frequencies in Hz corresponding to `midi` - ''' - return 440.0 * (2.0 ** ((midi - 69.0)/12.0)) + """ + return 440.0 * (2.0 ** ((midi - 69.0) / 12.0)) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..fbf4fd5e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,60 @@ +[tool:pytest] +addopts = --cov-report term-missing --cov mir_eval --cov-report=xml + +[pydocstyle] +# convention = numpy +# Below is equivalent to numpy convention + D400 and D205 +ignore = D107,D203,D205,D212,D213,D400,D402,D413,D415,D416,D417 + +[flake8] +count = True +statistics = True +show_source = True +select = + E9, + F63, + F7, + F82 + +[metadata] +name = mir_eval +version = attr: mir_eval.__version__ +description = Common metrics for common audio/music processing tasks. +author = Colin Raffel +author_email = craffel@gmail.com +url = https://github.com/craffel/mir_eval +long_description = file: README.rst +long_description_content_type = text/x-rst; charset=UTF-8 +license = MIT +python_requires = ">=3.7" +classifiers = + License :: OSI Approved :: MIT License + Programming Language :: Python + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + Topic :: Multimedia :: Sound/Audio :: Analysis + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + + +[options] +packages = find: +keywords = audio music mir dsp +install_requires = + numpy >= 1.15.4 + scipy >= 1.4.0 + +[options.extras_require] +display = + matplotlib >= 3.3.0 +testing = + matplotlib >= 3.3.0 + decorator + pytest + pytest-cov + pytest-mpl diff --git a/setup.py b/setup.py index 466cb66a..6b40b52b 100644 --- a/setup.py +++ b/setup.py @@ -1,39 +1,4 @@ from setuptools import setup -with open('README.rst') as file: - long_description = file.read() - -setup( - name='mir_eval', - version='0.7', - description='Common metrics for common audio/music processing tasks.', - author='Colin Raffel', - author_email='craffel@gmail.com', - url='https://github.com/craffel/mir_eval', - packages=['mir_eval'], - long_description=long_description, - classifiers=[ - "License :: OSI Approved :: MIT License", - "Programming Language :: Python", - 'Development Status :: 5 - Production/Stable', - "Intended Audience :: Developers", - "Topic :: Multimedia :: Sound/Audio :: Analysis", - "Programming Language :: Python :: 3", - ], - keywords='audio music mir dsp', - license='MIT', - install_requires=[ - 'numpy >= 1.7.0', - 'scipy >= 1.0.0', - ], - extras_require={ - 'display': ['matplotlib>=1.5.0'], - 'testing': ['matplotlib>=2.1.0', - 'decorator', - 'pytest', - 'pytest-cov', - 'pytest-mpl', - 'nose'] - }, - python_requires='>=3', -) +if __name__ == '__main__': + setup() diff --git a/tests/data/transcription_velocity/output2.json b/tests/data/transcription_velocity/output2.json index 7e4ee32f..57e7b478 100644 --- a/tests/data/transcription_velocity/output2.json +++ b/tests/data/transcription_velocity/output2.json @@ -1 +1 @@ -{"Precision": 0.1761055081458495, "Recall": 0.15655172413793103, "F-measure": 0.16575392479006937, "Average_Overlap_Ratio": 0.6339212298653343, "Precision_no_offset": 0.5865011636927852, "Recall_no_offset": 0.5213793103448275, "F-measure_no_offset": 0.552026286966046, "Average_Overlap_Ratio_no_offset": 0.4720752936805029} \ No newline at end of file +{"Precision": 0.17532971295577968, "Recall": 0.15586206896551724, "F-measure": 0.16502373128879153, "Average_Overlap_Ratio": 0.6346439422322657, "Precision_no_offset": 0.5865011636927852, "Recall_no_offset": 0.5213793103448275, "F-measure_no_offset": 0.552026286966046, "Average_Overlap_Ratio_no_offset": 0.4720752936805029} \ No newline at end of file diff --git a/tests/generate_data.py b/tests/generate_data.py index 2fcc19d1..b0f7496b 100755 --- a/tests/generate_data.py +++ b/tests/generate_data.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -''' +""" Generate data for regression tests. This is a pretty specialized file and should probably only be used if you know what you're doing. @@ -17,7 +17,7 @@ task1, task2, etc. are the tasks you'd like to generate data for. So, for example, if you'd like to generate data for onset and melody,run ./generate_data.py onset melody -''' +""" import mir_eval @@ -29,17 +29,17 @@ def load_separation_data(folder): - ''' + """ Loads in a stacked matrix of the .wavs in the provided folder. We need this because there's no specialized loader in .io for it. - ''' + """ data = [] global_fs = None # Load in each reference file in the supplied dir - for reference_file in glob.glob(os.path.join(folder, '*.wav')): + for reference_file in glob.glob(os.path.join(folder, "*.wav")): audio_data, fs = mir_eval.io.load_wav(reference_file) # Make sure fs is the same for all files - assert (global_fs is None or fs == global_fs) + assert global_fs is None or fs == global_fs global_fs = fs data.append(audio_data) return np.vstack(data) @@ -48,7 +48,8 @@ def load_separation_data(folder): def load_transcription_velocity(filename): """Loader for data in the format start, end, pitch, velocity.""" starts, ends, pitches, velocities = mir_eval.io.load_delimited( - filename, [float, float, int, int]) + filename, [float, float, int, int] + ) # Stack into an interval matrix intervals = np.array([starts, ends]).T # return pitches and velocities as np.ndarray @@ -58,43 +59,61 @@ def load_transcription_velocity(filename): return intervals, pitches, velocities -if __name__ == '__main__': +if __name__ == "__main__": # This dict will contain tuples of (submodule, loader, glob path) # The keys are 'beat', 'chord', etc. # Whatever is passed in as argv will be grabbed from it and the data for # that task will be generated. tasks = {} - tasks['beat'] = (mir_eval.beat, mir_eval.io.load_events, - 'data/beat/{}*.txt') - tasks['chord'] = (mir_eval.chord, mir_eval.io.load_labeled_intervals, - 'data/chord/{}*.lab') - tasks['melody'] = (mir_eval.melody, mir_eval.io.load_time_series, - 'data/melody/{}*.txt') - tasks['multipitch'] = (mir_eval.multipitch, - mir_eval.io.load_ragged_time_series, - 'data/multipitch/()*.txt') - tasks['onset'] = (mir_eval.onset, mir_eval.io.load_events, - 'data/onset/{}*.txt') - tasks['pattern'] = (mir_eval.pattern, mir_eval.io.load_patterns, - 'data/pattern/{}*.txt') - tasks['segment'] = (mir_eval.segment, mir_eval.io.load_labeled_intervals, - 'data/segment/{}*.lab') - tasks['separation'] = (mir_eval.separation, load_separation_data, - 'data/separation/{}*') - tasks['transcription'] = (mir_eval.transcription, - mir_eval.io.load_valued_intervals, - 'data/transcription/{}*.txt') - tasks['transcription_velocity'] = (mir_eval.transcription_velocity, - load_transcription_velocity, - 'data/transcription_velocity/{}*.txt') - tasks['key'] = (mir_eval.key, mir_eval.io.load_key, - 'data/key/{}*.txt') + tasks["beat"] = (mir_eval.beat, mir_eval.io.load_events, "data/beat/{}*.txt") + tasks["chord"] = ( + mir_eval.chord, + mir_eval.io.load_labeled_intervals, + "data/chord/{}*.lab", + ) + tasks["melody"] = ( + mir_eval.melody, + mir_eval.io.load_time_series, + "data/melody/{}*.txt", + ) + tasks["multipitch"] = ( + mir_eval.multipitch, + mir_eval.io.load_ragged_time_series, + "data/multipitch/()*.txt", + ) + tasks["onset"] = (mir_eval.onset, mir_eval.io.load_events, "data/onset/{}*.txt") + tasks["pattern"] = ( + mir_eval.pattern, + mir_eval.io.load_patterns, + "data/pattern/{}*.txt", + ) + tasks["segment"] = ( + mir_eval.segment, + mir_eval.io.load_labeled_intervals, + "data/segment/{}*.lab", + ) + tasks["separation"] = ( + mir_eval.separation, + load_separation_data, + "data/separation/{}*", + ) + tasks["transcription"] = ( + mir_eval.transcription, + mir_eval.io.load_valued_intervals, + "data/transcription/{}*.txt", + ) + tasks["transcription_velocity"] = ( + mir_eval.transcription_velocity, + load_transcription_velocity, + "data/transcription_velocity/{}*.txt", + ) + tasks["key"] = (mir_eval.key, mir_eval.io.load_key, "data/key/{}*.txt") # Get task keys from argv for task in sys.argv[1:]: - print('Generating data for {}'.format(task)) + print("Generating data for {}".format(task)) submodule, loader, data_glob = tasks[task] - ref_files = sorted(glob.glob(data_glob.format('ref'))) - est_files = sorted(glob.glob(data_glob.format('est'))) + ref_files = sorted(glob.glob(data_glob.format("ref"))) + est_files = sorted(glob.glob(data_glob.format("est"))) # Cycle through annotation file pairs for ref_file, est_file in zip(ref_files, est_files): # Use the loader to load in data @@ -106,7 +125,7 @@ def load_transcription_velocity(filename): else: scores = submodule.evaluate(ref_data, est_data) # Write out the resulting scores dict - output_file = ref_file.replace('ref', 'output') - output_file = os.path.splitext(output_file)[0] + '.json' - with open(output_file, 'w') as f: + output_file = ref_file.replace("ref", "output") + output_file = os.path.splitext(output_file)[0] + ".json" + with open(output_file, "w") as f: json.dump(scores, f) diff --git a/tests/test_alignment.py b/tests/test_alignment.py index 4bccc493..b5c2121d 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -5,7 +5,7 @@ import glob import json -import nose.tools +import pytest import numpy as np import mir_eval @@ -17,40 +17,19 @@ EST_GLOB = "data/alignment/est*.txt" SCORES_GLOB = "data/alignment/output*.json" +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) -def __unit_test_alignment_function(metric): - # Now test validation function - # alignments must be 1d ndarray - alignments = np.array([[1.0, 2.0]]) - nose.tools.assert_raises(ValueError, metric, alignments, alignments) - # alignments must be in seconds, and therefore not negative - alignments = np.array([-1.0, 2.0]) - nose.tools.assert_raises(ValueError, metric, alignments, alignments) - # alignments must be sorted - alignments = np.array([2.0, 1.0]) - nose.tools.assert_raises(ValueError, metric, alignments, alignments) - # predicted and estimated alignments must have same length - pred_alignments = np.array([[1.0, 2.0]]) - est_alignments = np.array([[1.0]]) - nose.tools.assert_raises( - ValueError, metric, est_alignments, pred_alignments - ) +assert len(ref_files) == len(est_files) == len(sco_files) > 0 +file_sets = list(zip(ref_files, est_files, sco_files)) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - -def test_alignment_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [ +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", + [ mir_eval.alignment.absolute_error, mir_eval.alignment.percentage_correct, mir_eval.alignment.percentage_correct_segments, @@ -60,27 +39,46 @@ def test_alignment_functions(): ) ), mir_eval.alignment.karaoke_perceptual_metric, - ]: - yield (__unit_test_alignment_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, "r") as f: - expected_scores = json.load(f) - # Load in an example alignment annotation - reference_alignments = mir_eval.io.load_events(ref_f) - # Load in an example alignment tracker output - estimated_alignments = mir_eval.io.load_events(est_f) - # Compute scores - scores = mir_eval.alignment.evaluate( - reference_alignments, estimated_alignments - ) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield ( - __check_score, - sco_f, - metric, - scores[metric], - expected_scores[metric], - ) + ], +) +@pytest.mark.parametrize( + "est_alignment, pred_alignment", + [ + ( + np.array([[1.0, 2.0]]), + np.array([[1.0, 2.0]]), + ), # alignments must be 1d ndarray + ( + np.array([[-1.0, 2.0]]), + np.array([[1.0, 2.0]]), + ), # alignments must be non-negative + (np.array([[2.0, 1.0]]), np.array([[1.0, 2.0]])), # alignments must be sorted + ( + np.array([[1.0, 2.0]]), + np.array([[1.0]]), + ), # alignments must have the same length + ], +) +def test_alignment_functions_fail(metric, est_alignment, pred_alignment): + metric(est_alignment, pred_alignment) + + +@pytest.fixture +def alignment_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + reference_alignments = mir_eval.io.load_events(ref_f) + estimated_alignments = mir_eval.io.load_events(est_f) + + return reference_alignments, estimated_alignments, expected_scores + + +@pytest.mark.parametrize("alignment_data", file_sets, indirect=True) +def test_alignment_functions(alignment_data): + reference_alignments, estimated_alignments, expected_scores = alignment_data + scores = mir_eval.alignment.evaluate(reference_alignments, estimated_alignments) + + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_beat.py b/tests/test_beat.py index a759091d..f86f3f38 100644 --- a/tests/test_beat.py +++ b/tests/test_beat.py @@ -1,20 +1,38 @@ -''' +""" Unit tests for mir_eval.beat -''' +""" import numpy as np import json import mir_eval import glob -import warnings -import nose.tools +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/beat/ref*.txt' -EST_GLOB = 'data/beat/est*.txt' -SCORES_GLOB = 'data/beat/output*.json' +REF_GLOB = "data/beat/ref*.txt" +EST_GLOB = "data/beat/est*.txt" +SCORES_GLOB = "data/beat/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def beat_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + reference_beats = mir_eval.io.load_events(ref_f) + estimated_beats = mir_eval.io.load_events(est_f) + + return reference_beats, estimated_beats, expected_scores def test_trim_beats(): @@ -25,71 +43,76 @@ def test_trim_beats(): assert np.allclose(mir_eval.beat.trim_beats(dummy_beats), expected_beats) -def __unit_test_beat_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # First, test for a warning on empty beats +@pytest.mark.parametrize( + "metric", + [ + mir_eval.beat.f_measure, + mir_eval.beat.cemgil, + mir_eval.beat.goto, + mir_eval.beat.p_score, + mir_eval.beat.continuity, + mir_eval.beat.information_gain, + ], +) +def test_beat_empty_warnings(metric): + with pytest.warns(UserWarning, match="Reference beats are empty."): metric(np.array([]), np.arange(10)) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Reference beats are empty." + with pytest.warns(UserWarning, match="Estimated beats are empty."): metric(np.arange(10), np.array([])) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Estimated beats are empty." - # And that the metric is 0 + with pytest.warns(UserWarning, match="beats are empty."): assert np.allclose(metric(np.array([]), np.array([])), 0) - # Now test validation function - beats must be 1d ndarray - beats = np.array([[1., 2.]]) - nose.tools.assert_raises(ValueError, metric, beats, beats) - # Beats must be in seconds (so not huge) - beats = np.array([1e10, 1e11]) - nose.tools.assert_raises(ValueError, metric, beats, beats) - # Beats must be sorted - beats = np.array([2., 1.]) - nose.tools.assert_raises(ValueError, metric, beats, beats) - - # Valid beats which are the same produce a score of 1 for all metrics + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.beat.f_measure, + mir_eval.beat.cemgil, + mir_eval.beat.goto, + mir_eval.beat.p_score, + mir_eval.beat.continuity, + mir_eval.beat.information_gain, + ], +) +@pytest.mark.parametrize( + "beats", + [ + np.array([[1.0, 2.0]]), # beats must be a 1d array + np.array([1e10, 1e11]), # beats must be not huge + np.array([2.0, 1.0]), # must be sorted + ], +) +@pytest.mark.xfail(raises=ValueError) +def test_beat_fail(metric, beats): + metric(beats, beats) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.beat.f_measure, + mir_eval.beat.cemgil, + mir_eval.beat.goto, + mir_eval.beat.p_score, + mir_eval.beat.continuity, + mir_eval.beat.information_gain, + ], +) +def test_beat_perfect(metric): beats = np.arange(10, dtype=np.float64) assert np.allclose(metric(beats, beats), 1) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_beat_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.beat.f_measure, - mir_eval.beat.cemgil, - mir_eval.beat.goto, - mir_eval.beat.p_score, - mir_eval.beat.continuity, - mir_eval.beat.information_gain]: - yield (__unit_test_beat_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in an example beat annotation - reference_beats = mir_eval.io.load_events(ref_f) - # Load in an example beat tracker output - estimated_beats = mir_eval.io.load_events(est_f) - # Compute scores - scores = mir_eval.beat.evaluate(reference_beats, estimated_beats) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) +@pytest.mark.parametrize("beat_data", file_sets, indirect=True) +def test_beat_functions(beat_data): + reference_beats, estimated_beats, expected_scores = beat_data + + # Compute scores + scores = mir_eval.beat.evaluate(reference_beats, estimated_beats) + # Compare them + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) # Unit tests for specific behavior not covered by the above @@ -97,36 +120,29 @@ def test_goto_proportion_correct(): # This covers the case when over 75% of the beat tracking is correct, and # more than 3 beats are incorrect assert mir_eval.beat.goto( - np.arange(100), np.append(np.arange(80), np.arange(80, 100) + .2)) + np.arange(100), np.append(np.arange(80), np.arange(80, 100) + 0.2) + ) -def test_warning_on_one_beat(): +@pytest.mark.parametrize( + "metric", + [mir_eval.beat.p_score, mir_eval.beat.continuity, mir_eval.beat.information_gain], +) +def test_warning_on_one_beat(metric): # This tests the metrics where passing only a single beat raises a warning # and returns 0 - for metric in [mir_eval.beat.p_score, mir_eval.beat.continuity, - mir_eval.beat.information_gain]: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # First, test for a warning on empty beats - metric(np.array([10]), np.arange(10)) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == ( - "Only one reference beat was provided, so beat intervals " - "cannot be computed.") - metric(np.arange(10), np.array([10.])) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == ( - "Only one estimated beat was provided, so beat intervals " - "cannot be computed.") - # And that the metric is 0 - assert np.allclose(metric(np.array([]), np.array([])), 0) + + with pytest.warns(UserWarning, match="Only one reference beat"): + metric(np.array([10]), np.arange(10)) + with pytest.warns(UserWarning, match="Only one estimated beat"): + metric(np.arange(10), np.array([10])) def test_continuity_edge_cases(): # There is some special-case logic for when there are few beats - assert np.allclose(mir_eval.beat.continuity( - np.array([6., 6.]), np.array([6., 7.])), 0.) - assert np.allclose(mir_eval.beat.continuity( - np.array([6., 6.]), np.array([6.5, 7.])), 0.) + assert np.allclose( + mir_eval.beat.continuity(np.array([6.0, 6.0]), np.array([6.0, 7.0])), 0.0 + ) + assert np.allclose( + mir_eval.beat.continuity(np.array([6.0, 6.0]), np.array([6.5, 7.0])), 0.0 + ) diff --git a/tests/test_chord.py b/tests/test_chord.py index 81d39402..0e70cf97 100644 --- a/tests/test_chord.py +++ b/tests/test_chord.py @@ -4,443 +4,501 @@ import mir_eval import numpy as np -import nose.tools -import warnings +import pytest import glob import json A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/chord/ref*.lab' -EST_GLOB = 'data/chord/est*.lab' -SCORES_GLOB = 'data/chord/output*.json' - - -def __check_valid(function, parameters, result): - ''' Helper function for checking the output of a function ''' - assert function(*parameters) == result - - -def __check_exception(function, parameters, exception): - ''' Makes sure the provided function throws the provided - exception given the provided input ''' - nose.tools.assert_raises(exception, function, *parameters) - - -def test_pitch_class_to_semitone(): - valid_classes = ['Gbb', 'G', 'G#', 'Cb', 'B#'] - valid_semitones = [5, 7, 8, 11, 0] - - for pitch_class, semitone in zip(valid_classes, valid_semitones): - yield (__check_valid, mir_eval.chord.pitch_class_to_semitone, - (pitch_class,), semitone) - - invalid_classes = ['Cab', '#C', 'bG'] - - for pitch_class in invalid_classes: - yield (__check_exception, mir_eval.chord.pitch_class_to_semitone, - (pitch_class,), mir_eval.chord.InvalidChordException) - - -def test_scale_degree_to_semitone(): - valid_degrees = ['b7', '#3', '1', 'b1', '#7', 'bb5', '11', '#13'] - valid_semitones = [10, 5, 0, -1, 12, 5, 17, 22] - - for scale_degree, semitone in zip(valid_degrees, valid_semitones): - yield (__check_valid, mir_eval.chord.scale_degree_to_semitone, - (scale_degree,), semitone) - - invalid_degrees = ['7b', '4#', '77', '15'] - - for scale_degree in invalid_degrees: - yield (__check_exception, mir_eval.chord.scale_degree_to_semitone, - (scale_degree,), mir_eval.chord.InvalidChordException) - - -def test_scale_degree_to_bitmap(): - - def __check_bitmaps(function, parameters, result): - actual = function(*parameters) - assert np.all(actual == result), (actual, result) - - valid_degrees = ['3', '*3', 'b1', '9'] - valid_bitmaps = [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] - - for scale_degree, bitmap in zip(valid_degrees, valid_bitmaps): - yield (__check_bitmaps, mir_eval.chord.scale_degree_to_bitmap, - (scale_degree, True, 12), np.array(bitmap)) - - yield (__check_bitmaps, mir_eval.chord.scale_degree_to_bitmap, - ('9', False, 12), np.array([0] * 12)) - - yield (__check_bitmaps, mir_eval.chord.scale_degree_to_bitmap, - ('9', False, 15), - np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - - -def test_validate_chord_label(): - valid_labels = ['C', 'Eb:min/5', 'A#:dim7', 'B:maj(*1,*5)/3', - 'A#:sus4', 'A:(9,11)'] +REF_GLOB = "data/chord/ref*.lab" +EST_GLOB = "data/chord/est*.lab" +SCORES_GLOB = "data/chord/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def chord_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + # Load in reference melody + ref_intervals, ref_labels = mir_eval.io.load_labeled_intervals(ref_f) + # Load in estimated melody + est_intervals, est_labels = mir_eval.io.load_labeled_intervals(est_f) + return ref_intervals, ref_labels, est_intervals, est_labels, expected_scores + + +@pytest.mark.parametrize( + "pitch, semitone", [("Gbb", 5), ("G", 7), ("G#", 8), ("Cb", 11), ("B#", 0)] +) +def test_pitch_class_to_semitone_valid(pitch, semitone): + assert mir_eval.chord.pitch_class_to_semitone(pitch) == semitone + + +@pytest.mark.parametrize("pitch", ["Cab", "#C", "bG"]) +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +def test_pitch_class_to_semitone_fail(pitch): + mir_eval.chord.pitch_class_to_semitone(pitch) + + +@pytest.mark.parametrize( + "degree, semitone", + [ + ("b7", 10), + ("#3", 5), + ("1", 0), + ("b1", -1), + ("#7", 12), + ("bb5", 5), + ("11", 17), + ("#13", 22), + ], +) +def test_scale_degree_to_semitone(degree, semitone): + assert mir_eval.chord.scale_degree_to_semitone(degree) == semitone + + +@pytest.mark.parametrize("degree", ["7b", "4#", "77", "15"]) +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +def test_scale_degree_to_semitone(degree): + mir_eval.chord.scale_degree_to_semitone(degree) + + +@pytest.mark.parametrize( + "degree, bitmap, modulo, length", + [ + ("3", [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], True, 12), + ("*3", [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0], True, 12), + ("b1", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], True, 12), + ("9", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], True, 12), + ("9", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], False, 12), + ("9", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], False, 15), + ], +) +def test_scale_degree_to_bitmap(degree, bitmap, modulo, length): + assert np.allclose( + mir_eval.chord.scale_degree_to_bitmap(degree, modulo=modulo, length=length), + bitmap, + ) + + +@pytest.mark.parametrize( + "label", ["C", "Eb:min/5", "A#:dim7", "B:maj(*1,*5)/3", "A#:sus4", "A:(9,11)"] +) +def test_validate_chord_label(label): # For valid labels, calling the function without an error = pass - for chord_label in valid_labels: - yield (mir_eval.chord.validate_chord_label, chord_label) - - invalid_labels = ["C::maj", "C//5", "C((4)", "C5))", - "C:maj(*3/3", "Cmaj*3/3)", 'asdf'] - - for chord_label in invalid_labels: - yield (__check_exception, mir_eval.chord.validate_chord_label, - (chord_label,), mir_eval.chord.InvalidChordException) - - -def test_split(): - labels = ['C', 'B:maj(*1,*3)/5', 'Ab:min/b3', 'N', 'G:(3)'] - splits = [['C', 'maj', set(), '1'], - ['B', 'maj', set(['*1', '*3']), '5'], - ['Ab', 'min', set(), 'b3'], - ['N', '', set(), ''], - ['G', '', set(['3']), '1']] - - for chord_label, split_chord in zip(labels, splits): - yield (__check_valid, mir_eval.chord.split, - (chord_label,), split_chord) - + mir_eval.chord.validate_chord_label(label) + + +@pytest.mark.parametrize( + "label", ["C::maj", "C//5", "C((4)", "C5))", "C:maj(*3/3", "Cmaj*3/3)", "asdf"] +) +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +def test_validate_bad_chord_label(label): + mir_eval.chord.validate_chord_label(label) + + +@pytest.mark.parametrize( + "label, split", + [ + ("C", ["C", "maj", set(), "1"]), + ("B:maj(*1,*3)/5", ["B", "maj", {"*1", "*3"}, "5"]), + ("Ab:min/b3", ["Ab", "min", set(), "b3"]), + ("N", ["N", "", set(), ""]), + ("G:(3)", ["G", "", {"3"}, "1"]), + ], +) +def test_split(label, split): + assert mir_eval.chord.split(label) == split + + +@pytest.mark.parametrize( + "label, split", + [("C", ["C", "maj", set(), "1"]), ("C:minmaj7", ["C", "min", {"7"}, "1"])], +) +def test_split_extended(label, split): # Test with reducing extended chords - labels = ['C', 'C:minmaj7'] - splits = [['C', 'maj', set(), '1'], - ['C', 'min', set(['7']), '1']] - for chord_label, split_chord in zip(labels, splits): - yield (__check_valid, mir_eval.chord.split, - (chord_label, True), split_chord) + mir_eval.chord.split(label, reduce_extended_chords=True) == split + +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +def test_split_fail(): # Test that an exception is raised when a chord with an omission but no # quality is supplied - yield (__check_exception, mir_eval.chord.split, - ('C(*5)',), mir_eval.chord.InvalidChordException) - - -def test_join(): - # Arguments are root, quality, extensions, bass - splits = [('F#', '', None, ''), - ('F#', 'hdim7', None, ''), - ('F#', '', ['*b3', '4'], ''), - ('F#', '', None, 'b7'), - ('F#', '', ['*b3', '4'], 'b7'), - ('F#', 'hdim7', None, 'b7'), - ('F#', 'hdim7', ['*b3', '4'], 'b7')] - labels = ['F#', 'F#:hdim7', 'F#:(*b3,4)', 'F#/b7', - 'F#:(*b3,4)/b7', 'F#:hdim7/b7', 'F#:hdim7(*b3,4)/b7'] - - for split_chord, chord_label in zip(splits, labels): - yield (__check_valid, mir_eval.chord.join, - split_chord, chord_label) - - -def test_rotate_bitmaps_to_roots(): - def __check_bitmaps(bitmaps, roots, expected_bitmaps): - ''' Helper function for checking bitmaps_to_roots ''' - ans = mir_eval.chord.rotate_bitmaps_to_roots(bitmaps, roots) - assert np.all(ans == expected_bitmaps) - - bitmaps = [ - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]] - roots = [0, 5, 11] - expected_bitmaps = [ - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]] - - # The function can operate on many bitmaps/roots at a time - # but we should only test them one at a time. - for bitmap, root, expected_bitmap in zip(bitmaps, roots, expected_bitmaps): - yield (__check_bitmaps, [bitmap], [root], [expected_bitmap]) - - -def test_encode(): - def __check_encode(label, expected_root, expected_intervals, - expected_bass, reduce_extended_chords, - strict_bass_intervals): - ''' Helper function for checking encode ''' - root, intervals, bass = mir_eval.chord.encode( - label, reduce_extended_chords=reduce_extended_chords, - strict_bass_intervals=strict_bass_intervals) - assert root == expected_root, (root, expected_root) - assert np.all(intervals == expected_intervals), (intervals, - expected_intervals) - assert bass == expected_bass, (bass, expected_bass) - - labels = ['B:maj(*1,*3)/5', 'G:dim', 'C:(3)/3', 'A:9/b3'] - expected_roots = [11, 7, 0, 9] - expected_intervals = [[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - # Note that extended scale degrees are dropped. - [1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0]] - expected_bass = [7, 0, 4, 3] - - args = list(zip(labels, expected_roots, expected_intervals, expected_bass)) - for label, e_root, e_interval, e_bass in args: - yield (__check_encode, label, e_root, e_interval, e_bass, False, False) - + mir_eval.chord.split("C(*5)") + + +# Arguments are root, quality, extensions, bass +@pytest.mark.parametrize( + "label, split", + [ + ("F#", ("F#", "", None, "")), + ("F#:hdim7", ("F#", "hdim7", None, "")), + ("F#:(*b3,4)", ("F#", "", ["*b3", "4"], "")), + ("F#/b7", ("F#", "", None, "b7")), + ("F#:(*b3,4)/b7", ("F#", "", ["*b3", "4"], "b7")), + ("F#:hdim7/b7", ("F#", "hdim7", None, "b7")), + ("F#:hdim7(*b3,4)/b7", ("F#", "hdim7", ["*b3", "4"], "b7")), + ], +) +def test_join(label, split): + # Test is relying on implicit parameter ordering here: root, quality, extensions, bass + assert mir_eval.chord.join(*split) == label + + +@pytest.mark.parametrize( + "bitmap, root, expected_bitmap", + [ + ([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], 0, [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), + ([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], 5, [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]), + ( + [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + 11, + [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], + ), + ], +) +def test_rotate_bitmaps_to_roots(bitmap, root, expected_bitmap): + ans = mir_eval.chord.rotate_bitmaps_to_roots([bitmap], [root]) + assert np.all(ans == [expected_bitmap]) + + +@pytest.mark.parametrize( + "label, e_root, e_interval, e_bass, reduce, strict", + [ + ("B:maj(*1,*3)/5", 11, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 7, False, False), + ("G:dim", 7, [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], 0, False, False), + ("C:(3)/3", 0, [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 4, False, False), + ("A:9/b3", 9, [1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0], 3, False, False), + ("G:dim(4)/6", 7, [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0], 9, False, False), + ("A:9", 9, [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0], 0, True, False), + ], +) +def test_chord_encode(label, e_root, e_interval, e_bass, reduce, strict): + root, intervals, bass = mir_eval.chord.encode( + label, reduce_extended_chords=reduce, strict_bass_intervals=strict + ) + assert root == e_root, (root, e_root) + assert np.all(intervals == e_interval), (intervals, e_interval) + assert bass == e_bass, (bass, e_bass) + + +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +def test_chord_encode_fail(): # Non-chord bass notes *must* be explicitly named as extensions when # strict_bass_intervals == True - yield (__check_exception, mir_eval.chord.encode, - ('G:dim(4)/6', False, True), mir_eval.chord.InvalidChordException) - - # Otherwise, we can cut a little slack. - yield (__check_encode, 'G:dim(4)/6', 7, - [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0], 9, - False, False) - - # Check that extended scale degrees are mapped back into pitch classes. - yield (__check_encode, 'A:9', 9, - [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0], 0, - True, False) + mir_eval.chord.encode( + "G:dim(4)/6", reduce_extended_chords=False, strict_bass_intervals=True + ) def test_encode_many(): - def __check_encode_many(labels, expected_roots, expected_intervals, - expected_basses): - ''' Does all of the logic for checking encode_many ''' - roots, intervals, basses = mir_eval.chord.encode_many(labels) - assert np.all(roots == expected_roots) - assert np.all(intervals == expected_intervals) - assert np.all(basses == expected_basses) - - labels = ['B:maj(*1,*3)/5', - 'B:maj(*1,*3)/5', - 'N', - 'C:min', - 'C:min'] + """Does all of the logic for checking encode_many""" + labels = ["B:maj(*1,*3)/5", "B:maj(*1,*3)/5", "N", "C:min", "C:min"] expected_roots = [11, 11, -1, 0, 0] expected_intervals = [ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]] + [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], + ] expected_basses = [7, 7, -1, 0, 0] - yield (__check_encode_many, labels, expected_roots, expected_intervals, - expected_basses) + roots, intervals, basses = mir_eval.chord.encode_many(labels) + assert np.all(roots == expected_roots) + assert np.all(intervals == expected_intervals) + assert np.all(basses == expected_basses) def __check_one_metric(metric, ref_label, est_label, score): - ''' Checks that a metric function produces score given ref_label and - est_label ''' + """Checks that a metric function produces score given ref_label and + est_label""" # We provide a dummy interval. We're just checking one pair # of labels at a time. assert metric([ref_label], [est_label]) == score def __check_not_comparable(metric, ref_label, est_label): - ''' Checks that ref_label is not comparable to est_label by metric ''' - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # Try to produce the warning - score = mir_eval.chord.weighted_accuracy(metric([ref_label], - [est_label]), - np.array([1])) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == ("No reference chords were comparable " - "to estimated chords, returning 0.") + """Checks that ref_label is not comparable to est_label by metric""" + # Try to produce the warning + with pytest.warns( + UserWarning, + match="No reference chords were comparable to estimated chords, returning 0.", + ): + score = mir_eval.chord.weighted_accuracy( + metric([ref_label], [est_label]), np.array([1]) + ) # And confirm that the metric is 0 assert np.allclose(score, 0) + # TODO(ejhumphrey): Comparison functions lacking unit tests. # test_root() -def test_mirex(): - ref_labels = ['N', 'C:maj', 'C:maj', 'C:maj', 'C:min', 'C:maj', - 'C:maj', 'G:min', 'C:maj', 'C:min', 'C:min', - 'C:maj', 'F:maj', 'C:maj7', 'A:maj', 'A:maj'] - est_labels = ['N', 'N', 'C:aug', 'C:dim', 'C:dim', 'C:5', - 'C:sus4', 'G:sus2', 'G:maj', 'C:hdim7', 'C:min7', - 'C:maj6', 'F:min6', 'C:minmaj7', 'A:7', 'A:9'] - scores = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 1.0, - 1.0, 0.0, 1.0, 1.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.mirex, - ref_label, est_label, score) - - ref_not_comparable = ['C:5', 'X'] - est_not_comparable = ['C:maj', 'N'] - - for ref_label, est_label in zip(ref_not_comparable, est_not_comparable): - yield (__check_not_comparable, mir_eval.chord.mirex, - ref_label, est_label) - - -def test_thirds(): - ref_labels = ['N', 'C:maj', 'C:maj', 'C:maj', 'C:min', - 'C:maj', 'G:min', 'C:maj', 'C:min', 'C:min', - 'C:maj', 'F:maj', 'C:maj', 'A:maj', 'A:maj'] - est_labels = ['N', 'N', 'C:aug', 'C:dim', 'C:dim', - 'C:sus4', 'G:sus2', 'G:maj', 'C:hdim7', 'C:min7', - 'C:maj6', 'F:min6', 'C:minmaj7', 'A:7', 'A:9'] - scores = [1.0, 0.0, 1.0, 0.0, 1.0, - 1.0, 0.0, 0.0, 1.0, 1.0, - 1.0, 0.0, 0.0, 1.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.thirds, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.thirds, 'X', 'N') - - -def test_thirds_inv(): - ref_labels = ['C:maj/5', 'G:min', 'C:maj', 'C:min/b3', 'C:min'] - est_labels = ['C:sus4/5', 'G:min/b3', 'C:maj/5', 'C:hdim7/b3', 'C:dim'] - scores = [1.0, 0.0, 0.0, 1.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.thirds_inv, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.thirds_inv, 'X', 'N') - - -def test_triads(): - ref_labels = ['C:min', 'C:maj', 'C:maj', 'C:min', 'C:maj', - 'C:maj', 'G:min', 'C:maj', 'C:min', 'C:min'] - est_labels = ['C:min7', 'C:7', 'C:aug', 'C:dim', 'C:sus2', - 'C:sus4', 'G:minmaj7', 'G:maj', 'C:hdim7', 'C:min6'] - scores = [1.0, 1.0, 0.0, 0.0, 0.0, - 0.0, 1.0, 0.0, 0.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.triads, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.triads, 'X', 'N') - - -def test_triads_inv(): - ref_labels = ['C:maj/5', 'G:min', 'C:maj', 'C:min/b3', 'C:min/b3'] - est_labels = ['C:maj7/5', 'G:min7/5', 'C:7/5', 'C:min6/b3', 'C:dim/b3'] - scores = [1.0, 0.0, 0.0, 1.0, 0.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.triads_inv, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.triads_inv, 'X', 'N') - - -def test_tetrads(): - ref_labels = ['C:min', 'C:maj', 'C:7', 'C:maj7', 'C:sus2', - 'C:7/3', 'G:min', 'C:maj', 'C:min', 'C:min'] - est_labels = ['C:min7', 'C:maj6', 'C:9', 'C:maj7/5', 'C:sus2/2', - 'C:11/b7', 'G:sus2', 'G:maj', 'C:hdim7', 'C:minmaj7'] - scores = [0.0, 0.0, 1.0, 1.0, 1.0, - 1.0, 0.0, 0.0, 0.0, 0.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.tetrads, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.tetrads, 'X', 'N') - - -def test_tetrads_inv(): - ref_labels = ['C:maj7/5', 'G:min', 'C:7/5', 'C:min/b3', 'C:min9'] - est_labels = ['C:maj7/3', 'G:min/b3', 'C:13/5', 'C:hdim7/b3', 'C:min7'] - scores = [0.0, 0.0, 1.0, 0.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.tetrads_inv, - ref_label, est_label, score) - - yield (__check_not_comparable, mir_eval.chord.tetrads_inv, 'X', 'N') - - -def test_majmin(): - ref_labels = ['N', 'C:maj', 'C:maj', 'C:min', 'G:maj7'] - est_labels = ['N', 'N', 'C:aug', 'C:dim', 'G'] - scores = [1.0, 0.0, 0.0, 0.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.majmin, - ref_label, est_label, score) - - ref_not_comparable = ['C:aug', 'X'] - est_not_comparable = ['C:maj', 'N'] - - for ref_label, est_label in zip(ref_not_comparable, est_not_comparable): - yield (__check_not_comparable, mir_eval.chord.majmin, - ref_label, est_label) - - -def test_majmin_inv(): - ref_labels = ['C:maj/5', 'G:min', 'C:maj/5', 'C:min7', - 'G:min/b3', 'C:maj7/5', 'C:7'] - est_labels = ['C:sus4/5', 'G:min/b3', 'C:maj/5', 'C:min', - 'G:min/b3', 'C:maj/5', 'C:maj'] - scores = [0.0, 0.0, 1.0, 1.0, - 1.0, 1.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.majmin_inv, - ref_label, est_label, score) - - ref_not_comparable = ['C:hdim7/b3', 'C:maj/4', 'C:maj/2', 'X'] - est_not_comparable = ['C:min/b3', 'C:maj/4', 'C:sus2/2', 'N'] - - for ref_label, est_label in zip(ref_not_comparable, est_not_comparable): - yield (__check_not_comparable, mir_eval.chord.majmin_inv, - ref_label, est_label) - - -def test_sevenths(): - ref_labels = ['C:min', 'C:maj', 'C:7', 'C:maj7', - 'C:7/3', 'G:min', 'C:maj', 'C:7'] - est_labels = ['C:min7', 'C:maj6', 'C:9', 'C:maj7/5', - 'C:11/b7', 'G:sus2', 'G:maj', 'C:maj7'] - scores = [0.0, 0.0, 1.0, 1.0, - 1.0, 0.0, 0.0, 0.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.sevenths, - ref_label, est_label, score) - - ref_not_comparable = ['C:sus2', 'C:hdim7', 'X'] - est_not_comparable = ['C:sus2/2', 'C:hdim7', 'N'] - for ref_label, est_label in zip(ref_not_comparable, est_not_comparable): - yield (__check_not_comparable, mir_eval.chord.sevenths, - ref_label, est_label) - - -def test_sevenths_inv(): - ref_labels = ['C:maj7/5', 'G:min', 'C:7/5', 'C:min7/b7'] - est_labels = ['C:maj7/3', 'G:min/b3', 'C:13/5', 'C:min7/b7'] - scores = [0.0, 0.0, 1.0, 1.0] - - for ref_label, est_label, score in zip(ref_labels, est_labels, scores): - yield (__check_one_metric, mir_eval.chord.sevenths_inv, - ref_label, est_label, score) - - ref_not_comparable = ['C:dim7/b3', 'X'] - est_not_comparable = ['C:dim7/b3', 'N'] - for ref_label, est_label in zip(ref_not_comparable, est_not_comparable): - yield (__check_not_comparable, mir_eval.chord.sevenths_inv, - ref_label, est_label) +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("N", "N", 1.0), + ("C:maj", "N", 0.0), + ("C:maj", "C:aug", 0.0), + ("C:maj", "C:dim", 0.0), + ("C:min", "C:dim", 0.0), + ("C:maj", "C:5", 0.0), + ("C:maj", "C:sus4", 0.0), + ("G:min", "G:sus2", 0.0), + ("C:maj", "G:maj", 0.0), + ("C:min", "C:hdim7", 0.0), + ("C:min", "C:min7", 1.0), + ("C:maj", "C:maj6", 1.0), + ("F:maj", "F:min6", 0.0), + ("C:maj7", "C:minmaj7", 1.0), + ("A:maj", "A:7", 1.0), + ("A:maj", "A:9", 1.0), + ], +) +def test_mirex(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.mirex, ref_label, est_label, score) + + +@pytest.mark.parametrize("ref, est", [("C:5", "C:maj"), ("X", "N")]) +def test_mirex_nocomp(ref, est): + __check_not_comparable(mir_eval.chord.mirex, ref, est) + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("N", "N", 1.0), + ("C:maj", "N", 0.0), + ("C:maj", "C:aug", 1.0), + ("C:maj", "C:dim", 0.0), + ("C:min", "C:dim", 1.0), + ("C:maj", "C:sus4", 1.0), + ("G:min", "G:sus2", 0.0), + ("C:maj", "G:maj", 0.0), + ("C:min", "C:hdim7", 1.0), + ("C:min", "C:min7", 1.0), + ("C:maj", "C:maj6", 1.0), + ("F:maj", "F:min6", 0.0), + ("C:maj", "C:minmaj7", 0.0), + ("A:maj", "A:7", 1.0), + ("A:maj", "A:9", 1.0), + ], +) +def test_thirds(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.thirds, ref_label, est_label, score) + + +def test_thirds_nocomp(): + __check_not_comparable(mir_eval.chord.thirds, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:maj/5", "C:sus4/5", 1.0), + ("G:min", "G:min/b3", 0.0), + ("C:maj", "C:maj/5", 0.0), + ("C:min/b3", "C:hdim7/b3", 1.0), + ("C:min", "C:dim", 1.0), + ], +) +def test_thirds_inv(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.thirds_inv, ref_label, est_label, score) + + +def test_thirds_inv_nocomp(): + __check_not_comparable(mir_eval.chord.thirds_inv, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:min", "C:min7", 1.0), + ("C:maj", "C:7", 1.0), + ("C:maj", "C:aug", 0.0), + ("C:min", "C:dim", 0.0), + ("C:maj", "C:sus2", 0.0), + ("C:maj", "C:sus4", 0.0), + ("G:min", "G:minmaj7", 1.0), + ("C:maj", "G:maj", 0.0), + ("C:min", "C:hdim7", 0.0), + ("C:min", "C:min6", 1.0), + ], +) +def test_triads(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.triads, ref_label, est_label, score) + + +def test_triads_nocomp(): + __check_not_comparable(mir_eval.chord.triads, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:maj/5", "C:maj7/5", 1.0), + ("G:min", "G:min7/5", 0.0), + ("C:maj", "C:7/5", 0.0), + ("C:min/b3", "C:min6/b3", 1.0), + ("C:min/b3", "C:dim/b3", 0.0), + ], +) +def test_triads_inv(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.triads_inv, ref_label, est_label, score) + + +def test_triads_inv_nocomp(): + __check_not_comparable(mir_eval.chord.triads_inv, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:min", "C:min7", 0.0), + ("C:maj", "C:maj6", 0.0), + ("C:7", "C:9", 1.0), + ("C:maj7", "C:maj7/5", 1.0), + ("C:sus2", "C:sus2/2", 1.0), + ("C:7/3", "C:11/b7", 1.0), + ("G:min", "G:sus2", 0.0), + ("C:maj", "G:maj", 0.0), + ("C:min", "C:hdim7", 0.0), + ("C:min", "C:minmaj7", 0.0), + ], +) +def test_tetrads(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.tetrads, ref_label, est_label, score) + + +def test_tetrads_nocomp(): + __check_not_comparable(mir_eval.chord.tetrads, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:maj7/5", "C:maj7/3", 0.0), + ("G:min", "G:min/b3", 0.0), + ("C:7/5", "C:13/5", 1.0), + ("C:min/b3", "C:hdim7/b3", 0.0), + ("C:min9", "C:min7", 1.0), + ], +) +def test_tetrads_inv(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.tetrads_inv, ref_label, est_label, score) + + +def test_tetrads_inv_nocomp(): + __check_not_comparable(mir_eval.chord.tetrads_inv, "X", "N") + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("N", "N", 1.0), + ("C:maj", "N", 0.0), + ("C:maj", "C:aug", 0.0), + ("C:min", "C:dim", 0.0), + ("G:maj7", "G", 1.0), + ], +) +def test_majmin(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.majmin, ref_label, est_label, score) + + +@pytest.mark.parametrize("ref_label, est_label", [("C:aug", "C:maj"), ("X", "N")]) +def test_majmin_nocomp(ref_label, est_label): + __check_not_comparable(mir_eval.chord.majmin, ref_label, est_label) + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:maj/5", "C:sus4/5", 0.0), + ("G:min", "G:min/b3", 0.0), + ("C:maj/5", "C:maj/5", 1.0), + ("C:min7", "C:min", 1.0), + ("G:min/b3", "G:min/b3", 1.0), + ("C:maj7/5", "C:maj/5", 1.0), + ("C:7", "C:maj", 1.0), + ], +) +def test_majmin_inv(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.majmin_inv, ref_label, est_label, score) + + +@pytest.mark.parametrize( + "ref_label, est_label", + [ + ("C:hdim7/b3", "C:min/b3"), + ("C:maj/4", "C:maj/4"), + ("C:maj/2", "C:sus2/2"), + ("X", "N"), + ], +) +def test_majmin_inv_nocomp(ref_label, est_label): + __check_not_comparable(mir_eval.chord.majmin_inv, ref_label, est_label) + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:min", "C:min7", 0.0), + ("C:maj", "C:maj6", 0.0), + ("C:7", "C:9", 1.0), + ("C:maj7", "C:maj7/5", 1.0), + ("C:7/3", "C:11/b7", 1.0), + ("G:min", "G:sus2", 0.0), + ("C:maj", "G:maj", 0.0), + ("C:7", "C:maj7", 0.0), + ], +) +def test_sevenths(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.sevenths, ref_label, est_label, score) + + +@pytest.mark.parametrize( + "ref_label, est_label", [("C:sus2", "C:sus2/2"), ("C:hdim7", "C:hdim7"), ("X", "N")] +) +def test_sevenths_nocomp(ref_label, est_label): + __check_not_comparable(mir_eval.chord.sevenths, ref_label, est_label) + + +@pytest.mark.parametrize( + "ref_label, est_label, score", + [ + ("C:maj7/5", "C:maj7/3", 0.0), + ("G:min", "G:min/b3", 0.0), + ("C:7/5", "C:13/5", 1.0), + ("C:min7/b7", "C:min7/b7", 1.0), + ], +) +def test_sevenths_inv(ref_label, est_label, score): + __check_one_metric(mir_eval.chord.sevenths_inv, ref_label, est_label, score) + + +@pytest.mark.parametrize( + "ref_label, est_label", [("C:dim7/b3", "C:dim7/b3"), ("X", "N")] +) +def test_sevenths_inv_nocomp(ref_label, est_label): + __check_not_comparable(mir_eval.chord.sevenths_inv, ref_label, est_label) def test_directional_hamming_distance(): - ref_ivs = np.array([[0., 1.], [1., 2.], [2., 3.]]) - est_ivs = np.array([[0., 0.9], [0.9, 1.8], [1.8, 2.5]]) - dhd_ref_to_est = (0.1 + 0.2 + 0.5) / 3. + ref_ivs = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]) + est_ivs = np.array([[0.0, 0.9], [0.9, 1.8], [1.8, 2.5]]) + dhd_ref_to_est = (0.1 + 0.2 + 0.5) / 3.0 dhd_est_to_ref = (0.0 + 0.1 + 0.2) / 2.5 dhd = mir_eval.chord.directional_hamming_distance @@ -449,24 +507,26 @@ def test_directional_hamming_distance(): assert np.allclose(0, dhd(ref_ivs, ref_ivs)) assert np.allclose(0, dhd(est_ivs, est_ivs)) - ivs_overlap_all = np.array([[0., 1.], [0.9, 2.]]) - ivs_overlap_one = np.array([[0., 1.], [0.9, 2.], [2., 3.]]) - nose.tools.assert_raises(ValueError, dhd, ivs_overlap_all, est_ivs) - nose.tools.assert_raises(ValueError, dhd, ivs_overlap_one, est_ivs) + ivs_overlap_all = np.array([[0.0, 1.0], [0.9, 2.0]]) + ivs_overlap_one = np.array([[0.0, 1.0], [0.9, 2.0], [2.0, 3.0]]) + with pytest.raises(ValueError): + dhd(ivs_overlap_all, est_ivs) + with pytest.raises(ValueError): + dhd(ivs_overlap_one, est_ivs) def test_segmentation_functions(): - ref_ivs = np.array([[0., 2.], [2., 2.5], [2.5, 3.2]]) - est_ivs = np.array([[0., 3.], [3., 3.5]]) - true_oseg = 1. - 0.2 / 3.2 - true_useg = 1. - (1. + 0.2) / 3.5 + ref_ivs = np.array([[0.0, 2.0], [2.0, 2.5], [2.5, 3.2]]) + est_ivs = np.array([[0.0, 3.0], [3.0, 3.5]]) + true_oseg = 1.0 - 0.2 / 3.2 + true_useg = 1.0 - (1.0 + 0.2) / 3.5 true_seg = min(true_oseg, true_useg) assert np.allclose(true_oseg, mir_eval.chord.overseg(ref_ivs, est_ivs)) assert np.allclose(true_useg, mir_eval.chord.underseg(ref_ivs, est_ivs)) assert np.allclose(true_seg, mir_eval.chord.seg(ref_ivs, est_ivs)) - ref_ivs = np.array([[0., 2.], [2., 2.5], [2.5, 3.2]]) - est_ivs = np.array([[0., 2.], [2., 2.5], [2.5, 3.2]]) + ref_ivs = np.array([[0.0, 2.0], [2.0, 2.5], [2.5, 3.2]]) + est_ivs = np.array([[0.0, 2.0], [2.0, 2.5], [2.5, 3.2]]) true_oseg = 1.0 true_useg = 1.0 true_seg = 1.0 @@ -474,8 +534,8 @@ def test_segmentation_functions(): assert np.allclose(true_useg, mir_eval.chord.underseg(ref_ivs, est_ivs)) assert np.allclose(true_seg, mir_eval.chord.seg(ref_ivs, est_ivs)) - ref_ivs = np.array([[0., 2.], [2., 2.5], [2.5, 3.2]]) - est_ivs = np.array([[0., 3.2]]) + ref_ivs = np.array([[0.0, 2.0], [2.0, 2.5], [2.5, 3.2]]) + est_ivs = np.array([[0.0, 3.2]]) true_oseg = 1.0 true_useg = 1 - 1.2 / 3.2 true_seg = min(true_oseg, true_useg) @@ -483,7 +543,7 @@ def test_segmentation_functions(): assert np.allclose(true_useg, mir_eval.chord.underseg(ref_ivs, est_ivs)) assert np.allclose(true_seg, mir_eval.chord.seg(ref_ivs, est_ivs)) - ref_ivs = np.array([[0., 2.], [2., 2.5], [2.5, 3.2]]) + ref_ivs = np.array([[0.0, 2.0], [2.0, 2.5], [2.5, 3.2]]) est_ivs = np.array([[3.2, 3.5]]) true_oseg = 1.0 true_useg = 1.0 @@ -494,34 +554,35 @@ def test_segmentation_functions(): def test_merge_chord_intervals(): - intervals = np.array([[0., 1.], [1., 2.], [2., 3], [3., 4.], [4., 5.]]) - labels = ['C:maj', 'C:(1,3,5)', 'A:maj', 'A:maj7', 'A:maj7/3'] - assert np.allclose(np.array([[0., 2.], [2., 3], [3., 4.], [4., 5.]]), - mir_eval.chord.merge_chord_intervals(intervals, labels)) + intervals = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3], [3.0, 4.0], [4.0, 5.0]]) + labels = ["C:maj", "C:(1,3,5)", "A:maj", "A:maj7", "A:maj7/3"] + assert np.allclose( + np.array([[0.0, 2.0], [2.0, 3], [3.0, 4.0], [4.0, 5.0]]), + mir_eval.chord.merge_chord_intervals(intervals, labels), + ) def test_weighted_accuracy(): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # First, test for a warning on empty beats - score = mir_eval.chord.weighted_accuracy(np.array([1, 0, 1]), - np.array([0, 0, 0])) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == 'No nonzero weights, returning 0' + # First, test for a warning on empty beats + with pytest.warns(UserWarning, match="No nonzero weights, returning 0"): + score = mir_eval.chord.weighted_accuracy( + np.array([1, 0, 1]), np.array([0, 0, 0]) + ) # And that the metric is 0 assert np.allclose(score, 0) # len(comparisons) must equal len(weights) comparisons = np.array([1, 0, 1]) weights = np.array([1, 1]) - nose.tools.assert_raises(ValueError, mir_eval.chord.weighted_accuracy, - comparisons, weights) + + with pytest.raises(ValueError): + mir_eval.chord.weighted_accuracy(comparisons, weights) + # Weights must all be positive comparisons = np.array([1, 1]) weights = np.array([-1, -1]) - nose.tools.assert_raises(ValueError, mir_eval.chord.weighted_accuracy, - comparisons, weights) + with pytest.raises(ValueError): + mir_eval.chord.weighted_accuracy(comparisons, weights) # Make sure accuracy = 1 and 0 when all comparisons are True and False resp comparisons = np.array([1, 1, 1]) @@ -533,61 +594,47 @@ def test_weighted_accuracy(): assert np.allclose(score, 0) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) +@pytest.mark.parametrize("chord_data", file_sets, indirect=True) +def test_chord_functions(chord_data): + ref_intervals, ref_labels, est_intervals, est_labels, expected_scores = chord_data - -def test_chord_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in an example beat annotation - ref_intervals, ref_labels = mir_eval.io.load_labeled_intervals(ref_f) - # Load in an example beat tracker output - est_intervals, est_labels = mir_eval.io.load_labeled_intervals(est_f) - # Compute scores - scores = mir_eval.chord.evaluate(ref_intervals, ref_labels, - est_intervals, est_labels) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) + # Compute scores + scores = mir_eval.chord.evaluate( + ref_intervals, ref_labels, est_intervals, est_labels + ) + # Compare them + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) def test_quality_to_bitmap(): - # Test simple case - assert np.all(mir_eval.chord.quality_to_bitmap('maj') == np.array( - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])) + assert np.all( + mir_eval.chord.quality_to_bitmap("maj") + == np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]) + ) + +@pytest.mark.xfail(raises=mir_eval.chord.InvalidChordException) +@pytest.mark.parametrize("quality", ["maj5", "2", "#7"]) +def test_quality_to_bitmap_fail(quality): # Check exceptions for qualities not in the QUALITIES list - invalid_qualities = ['maj5', '2', '#7'] - for quality in invalid_qualities: - yield (__check_exception, mir_eval.chord.quality_to_bitmap, - (quality,), mir_eval.chord.InvalidChordException) + mir_eval.chord.quality_to_bitmap(quality) def test_validate(): # Test that the validate function raises the appropriate errors and # warnings - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with pytest.warns() as w: # First, test for warnings on empty labels mir_eval.chord.validate([], []) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Estimated labels are empty" - assert issubclass(w[-2].category, UserWarning) - assert str(w[-2].message) == "Reference labels are empty" - # Test that error is thrown on different-length labels - nose.tools.assert_raises( - ValueError, mir_eval.chord.validate, [], ['C']) + assert len(w) == 2 + assert issubclass(w[-1].category, UserWarning) + assert str(w[-1].message) == "Estimated labels are empty" + assert issubclass(w[-2].category, UserWarning) + assert str(w[-2].message) == "Reference labels are empty" + + # Test that error is thrown on different-length labels + with pytest.raises(ValueError): + mir_eval.chord.validate([], ["C"]) diff --git a/tests/test_display.py b/tests/test_display.py index f714f3ae..748bf476 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -1,19 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -'''Unit tests for the display module''' +"""Unit tests for the display module""" # For testing purposes, clobber the rcfile import matplotlib -matplotlib.use('Agg') # nopep8 +matplotlib.use("Agg") # nopep8 import matplotlib.pyplot as plt import numpy as np import pytest -from nose.tools import raises - # We'll make a decorator to handle style contexts from decorator import decorator @@ -26,20 +24,24 @@ from mir_eval.io import load_wav +pytestmark = pytest.mark.skip( + reason="disabling display tests until after merger of #370" +) + + @decorator def styled(f, *args, **kwargs): matplotlib.rcdefaults() return f(*args, **kwargs) -@pytest.mark.mpl_image_compare(baseline_images=['segment'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["segment"], extensions=["png"]) @styled def test_display_segment(): - plt.figure() # Load some segment data - intervals, labels = load_labeled_intervals('data/segment/ref00.lab') + intervals, labels = load_labeled_intervals("data/segment/ref00.lab") # Plot the segments with no labels mir_eval.display.segments(intervals, labels, text=False) @@ -48,329 +50,324 @@ def test_display_segment(): plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['segment_text'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["segment_text"], extensions=["png"]) @styled def test_display_segment_text(): plt.figure() # Load some segment data - intervals, labels = load_labeled_intervals('data/segment/ref00.lab') + intervals, labels = load_labeled_intervals("data/segment/ref00.lab") # Plot the segments with no labels mir_eval.display.segments(intervals, labels, text=True) -@pytest.mark.mpl_image_compare(baseline_images=['labeled_intervals'], extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["labeled_intervals"], extensions=["png"] +) @styled def test_display_labeled_intervals(): - plt.figure() # Load some chord data - intervals, labels = load_labeled_intervals('data/chord/ref01.lab') + intervals, labels = load_labeled_intervals("data/chord/ref01.lab") # Plot the chords with nothing fancy mir_eval.display.labeled_intervals(intervals, labels) -@pytest.mark.mpl_image_compare(baseline_images=['labeled_intervals_noextend'], - extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["labeled_intervals_noextend"], extensions=["png"] +) @styled def test_display_labeled_intervals_noextend(): - plt.figure() # Load some chord data - intervals, labels = load_labeled_intervals('data/chord/ref01.lab') + intervals, labels = load_labeled_intervals("data/chord/ref01.lab") # Plot the chords with nothing fancy ax = plt.axes() ax.set_yticklabels([]) - mir_eval.display.labeled_intervals(intervals, labels, - label_set=[], - extend_labels=False, - ax=ax) + mir_eval.display.labeled_intervals( + intervals, labels, label_set=[], extend_labels=False, ax=ax + ) -@pytest.mark.mpl_image_compare(baseline_images=['labeled_intervals_compare'], - extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["labeled_intervals_compare"], extensions=["png"] +) @styled def test_display_labeled_intervals_compare(): - plt.figure() # Load some chord data - ref_int, ref_labels = load_labeled_intervals('data/chord/ref01.lab') - est_int, est_labels = load_labeled_intervals('data/chord/est01.lab') + ref_int, ref_labels = load_labeled_intervals("data/chord/ref01.lab") + est_int, est_labels = load_labeled_intervals("data/chord/est01.lab") # Plot reference and estimates using label set extension - mir_eval.display.labeled_intervals(ref_int, ref_labels, - alpha=0.5, label='Reference') - mir_eval.display.labeled_intervals(est_int, est_labels, - alpha=0.5, label='Estimate') + mir_eval.display.labeled_intervals( + ref_int, ref_labels, alpha=0.5, label="Reference" + ) + mir_eval.display.labeled_intervals(est_int, est_labels, alpha=0.5, label="Estimate") plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['labeled_intervals_compare_noextend'], - extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["labeled_intervals_compare_noextend"], extensions=["png"] +) @styled def test_display_labeled_intervals_compare_noextend(): - plt.figure() # Load some chord data - ref_int, ref_labels = load_labeled_intervals('data/chord/ref01.lab') - est_int, est_labels = load_labeled_intervals('data/chord/est01.lab') + ref_int, ref_labels = load_labeled_intervals("data/chord/ref01.lab") + est_int, est_labels = load_labeled_intervals("data/chord/est01.lab") # Plot reference and estimate, but only use the reference labels - mir_eval.display.labeled_intervals(ref_int, ref_labels, - alpha=0.5, label='Reference') - mir_eval.display.labeled_intervals(est_int, est_labels, - extend_labels=False, - alpha=0.5, label='Estimate') + mir_eval.display.labeled_intervals( + ref_int, ref_labels, alpha=0.5, label="Reference" + ) + mir_eval.display.labeled_intervals( + est_int, est_labels, extend_labels=False, alpha=0.5, label="Estimate" + ) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['labeled_intervals_compare_common'], - extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["labeled_intervals_compare_common"], extensions=["png"] +) @styled def test_display_labeled_intervals_compare_common(): - plt.figure() # Load some chord data - ref_int, ref_labels = load_labeled_intervals('data/chord/ref01.lab') - est_int, est_labels = load_labeled_intervals('data/chord/est01.lab') + ref_int, ref_labels = load_labeled_intervals("data/chord/ref01.lab") + est_int, est_labels = load_labeled_intervals("data/chord/est01.lab") label_set = list(sorted(set(ref_labels) | set(est_labels))) # Plot reference and estimate with a common label set - mir_eval.display.labeled_intervals(ref_int, ref_labels, - label_set=label_set, - alpha=0.5, label='Reference') - mir_eval.display.labeled_intervals(est_int, est_labels, - label_set=label_set, - alpha=0.5, label='Estimate') + mir_eval.display.labeled_intervals( + ref_int, ref_labels, label_set=label_set, alpha=0.5, label="Reference" + ) + mir_eval.display.labeled_intervals( + est_int, est_labels, label_set=label_set, alpha=0.5, label="Estimate" + ) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['hierarchy_nolabel'], extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["hierarchy_nolabel"], extensions=["png"] +) @styled def test_display_hierarchy_nolabel(): - plt.figure() # Load some chord data - int0, lab0 = load_labeled_intervals('data/hierarchy/ref00.lab') - int1, lab1 = load_labeled_intervals('data/hierarchy/ref01.lab') + int0, lab0 = load_labeled_intervals("data/hierarchy/ref00.lab") + int1, lab1 = load_labeled_intervals("data/hierarchy/ref01.lab") # Plot reference and estimate with a common label set - mir_eval.display.hierarchy([int0, int1], - [lab0, lab1]) + mir_eval.display.hierarchy([int0, int1], [lab0, lab1]) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['hierarchy_label'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["hierarchy_label"], extensions=["png"]) @styled def test_display_hierarchy_label(): - plt.figure() # Load some chord data - int0, lab0 = load_labeled_intervals('data/hierarchy/ref00.lab') - int1, lab1 = load_labeled_intervals('data/hierarchy/ref01.lab') + int0, lab0 = load_labeled_intervals("data/hierarchy/ref00.lab") + int1, lab1 = load_labeled_intervals("data/hierarchy/ref01.lab") # Plot reference and estimate with a common label set - mir_eval.display.hierarchy([int0, int1], - [lab0, lab1], - levels=['Large', 'Small']) + mir_eval.display.hierarchy([int0, int1], [lab0, lab1], levels=["Large", "Small"]) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['pitch_hz'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["pitch_hz"], extensions=["png"]) @styled def test_pitch_hz(): plt.figure() - ref_times, ref_freqs = load_labeled_events('data/melody/ref00.txt') - est_times, est_freqs = load_labeled_events('data/melody/est00.txt') + ref_times, ref_freqs = load_labeled_events("data/melody/ref00.txt") + est_times, est_freqs = load_labeled_events("data/melody/est00.txt") # Plot pitches on a Hz scale - mir_eval.display.pitch(ref_times, ref_freqs, unvoiced=True, - label='Reference') - mir_eval.display.pitch(est_times, est_freqs, unvoiced=True, - label='Estimate') + mir_eval.display.pitch(ref_times, ref_freqs, unvoiced=True, label="Reference") + mir_eval.display.pitch(est_times, est_freqs, unvoiced=True, label="Estimate") plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['pitch_midi'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["pitch_midi"], extensions=["png"]) @styled def test_pitch_midi(): plt.figure() - times, freqs = load_labeled_events('data/melody/ref00.txt') + times, freqs = load_labeled_events("data/melody/ref00.txt") # Plot pitches on a midi scale with note tickers mir_eval.display.pitch(times, freqs, midi=True) mir_eval.display.ticker_notes() -@pytest.mark.mpl_image_compare(baseline_images=['pitch_midi_hz'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["pitch_midi_hz"], extensions=["png"]) @styled def test_pitch_midi_hz(): plt.figure() - times, freqs = load_labeled_events('data/melody/ref00.txt') + times, freqs = load_labeled_events("data/melody/ref00.txt") # Plot pitches on a midi scale with note tickers mir_eval.display.pitch(times, freqs, midi=True) mir_eval.display.ticker_pitch() -@pytest.mark.mpl_image_compare(baseline_images=['multipitch_hz_unvoiced'], - extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["multipitch_hz_unvoiced"], extensions=["png"] +) @styled def test_multipitch_hz_unvoiced(): plt.figure() - times, pitches = load_ragged_time_series('data/multipitch/est01.txt') + times, pitches = load_ragged_time_series("data/multipitch/est01.txt") # Plot pitches on a midi scale with note tickers mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=True) -@pytest.mark.mpl_image_compare(baseline_images=['multipitch_hz_voiced'], extensions=['png']) +@pytest.mark.mpl_image_compare( + baseline_images=["multipitch_hz_voiced"], extensions=["png"] +) @styled def test_multipitch_hz_voiced(): plt.figure() - times, pitches = load_ragged_time_series('data/multipitch/est01.txt') + times, pitches = load_ragged_time_series("data/multipitch/est01.txt") mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=False) -@pytest.mark.mpl_image_compare(baseline_images=['multipitch_midi'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["multipitch_midi"], extensions=["png"]) @styled def test_multipitch_midi(): plt.figure() - ref_t, ref_p = load_ragged_time_series('data/multipitch/ref01.txt') - est_t, est_p = load_ragged_time_series('data/multipitch/est01.txt') + ref_t, ref_p = load_ragged_time_series("data/multipitch/ref01.txt") + est_t, est_p = load_ragged_time_series("data/multipitch/est01.txt") # Plot pitches on a midi scale with note tickers - mir_eval.display.multipitch(ref_t, ref_p, midi=True, - alpha=0.5, label='Reference') - mir_eval.display.multipitch(est_t, est_p, midi=True, - alpha=0.5, label='Estimate') + mir_eval.display.multipitch(ref_t, ref_p, midi=True, alpha=0.5, label="Reference") + mir_eval.display.multipitch(est_t, est_p, midi=True, alpha=0.5, label="Estimate") plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['piano_roll'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["piano_roll"], extensions=["png"]) @styled def test_pianoroll(): plt.figure() - ref_t, ref_p = load_valued_intervals('data/transcription/ref04.txt') - est_t, est_p = load_valued_intervals('data/transcription/est04.txt') + ref_t, ref_p = load_valued_intervals("data/transcription/ref04.txt") + est_t, est_p = load_valued_intervals("data/transcription/est04.txt") - mir_eval.display.piano_roll(ref_t, ref_p, - label='Reference', alpha=0.5) - mir_eval.display.piano_roll(est_t, est_p, - label='Estimate', alpha=0.5, facecolor='r') + mir_eval.display.piano_roll(ref_t, ref_p, label="Reference", alpha=0.5) + mir_eval.display.piano_roll( + est_t, est_p, label="Estimate", alpha=0.5, facecolor="r" + ) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['piano_roll_midi'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["piano_roll_midi"], extensions=["png"]) @styled def test_pianoroll_midi(): plt.figure() - ref_t, ref_p = load_valued_intervals('data/transcription/ref04.txt') - est_t, est_p = load_valued_intervals('data/transcription/est04.txt') + ref_t, ref_p = load_valued_intervals("data/transcription/ref04.txt") + est_t, est_p = load_valued_intervals("data/transcription/est04.txt") ref_midi = mir_eval.util.hz_to_midi(ref_p) est_midi = mir_eval.util.hz_to_midi(est_p) - mir_eval.display.piano_roll(ref_t, midi=ref_midi, - label='Reference', alpha=0.5) - mir_eval.display.piano_roll(est_t, midi=est_midi, - label='Estimate', alpha=0.5, facecolor='r') + mir_eval.display.piano_roll(ref_t, midi=ref_midi, label="Reference", alpha=0.5) + mir_eval.display.piano_roll( + est_t, midi=est_midi, label="Estimate", alpha=0.5, facecolor="r" + ) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['ticker_midi_zoom'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["ticker_midi_zoom"], extensions=["png"]) @styled def test_ticker_midi_zoom(): - plt.figure() plt.plot(np.arange(3)) mir_eval.display.ticker_notes() -@pytest.mark.mpl_image_compare(baseline_images=['separation'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["separation"], extensions=["png"]) @styled def test_separation(): plt.figure() - x0, fs = load_wav('data/separation/ref05/0.wav') - x1, fs = load_wav('data/separation/ref05/1.wav') - x2, fs = load_wav('data/separation/ref05/2.wav') + x0, fs = load_wav("data/separation/ref05/0.wav") + x1, fs = load_wav("data/separation/ref05/1.wav") + x2, fs = load_wav("data/separation/ref05/2.wav") mir_eval.display.separation([x0, x1, x2], fs=fs) -@pytest.mark.mpl_image_compare(baseline_images=['separation_label'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["separation_label"], extensions=["png"]) @styled def test_separation_label(): plt.figure() - x0, fs = load_wav('data/separation/ref05/0.wav') - x1, fs = load_wav('data/separation/ref05/1.wav') - x2, fs = load_wav('data/separation/ref05/2.wav') + x0, fs = load_wav("data/separation/ref05/0.wav") + x1, fs = load_wav("data/separation/ref05/1.wav") + x2, fs = load_wav("data/separation/ref05/2.wav") - mir_eval.display.separation([x0, x1, x2], fs=fs, - labels=['Alice', 'Bob', 'Carol']) + mir_eval.display.separation([x0, x1, x2], fs=fs, labels=["Alice", "Bob", "Carol"]) plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['events'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["events"], extensions=["png"]) @styled def test_events(): plt.figure() # Load some event data - beats_ref = mir_eval.io.load_events('data/beat/ref00.txt')[:30] - beats_est = mir_eval.io.load_events('data/beat/est00.txt')[:30] + beats_ref = mir_eval.io.load_events("data/beat/ref00.txt")[:30] + beats_est = mir_eval.io.load_events("data/beat/est00.txt")[:30] # Plot both with labels - mir_eval.display.events(beats_ref, label='reference') - mir_eval.display.events(beats_est, label='estimate') + mir_eval.display.events(beats_ref, label="reference") + mir_eval.display.events(beats_est, label="estimate") plt.legend() -@pytest.mark.mpl_image_compare(baseline_images=['labeled_events'], extensions=['png']) +@pytest.mark.mpl_image_compare(baseline_images=["labeled_events"], extensions=["png"]) @styled def test_labeled_events(): plt.figure() # Load some event data - beats_ref = mir_eval.io.load_events('data/beat/ref00.txt')[:10] + beats_ref = mir_eval.io.load_events("data/beat/ref00.txt")[:10] - labels = list('abcdefghijklmnop') + labels = list("abcdefghijklmnop") # Plot both with labels mir_eval.display.events(beats_ref, labels) -@raises(ValueError) +@pytest.mark.xfail(raises=ValueError) def test_pianoroll_nopitch_nomidi(): # Issue 214 mir_eval.display.piano_roll([[0, 1]]) diff --git a/tests/test_hierarchy.py b/tests/test_hierarchy.py index a20a902e..882911e3 100644 --- a/tests/test_hierarchy.py +++ b/tests/test_hierarchy.py @@ -1,24 +1,24 @@ -''' +""" Unit tests for mir_eval.hierarchy -''' +""" -from glob import glob +import glob import re -import warnings import json import numpy as np import scipy.sparse import mir_eval - -from nose.tools import raises +import pytest A_TOL = 1e-12 -def test_tmeasure_pass(): +@pytest.mark.parametrize("window", [5, 10, 15, 30, 90, None]) +@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0]) +def test_tmeasure_pass(window, frame_size): # The estimate here gets none of the structure correct. ref = [[[0, 30]], [[0, 15], [15, 30]]] # convert to arrays @@ -26,105 +26,75 @@ def test_tmeasure_pass(): est = ref[:1] - def __test(window, frame_size): - # The estimate should get 0 score here - scores = mir_eval.hierarchy.tmeasure(ref, est, - window=window, - frame_size=frame_size) - - for k in scores: - assert k == 0.0 + # The estimate should get 0 score here + scores = mir_eval.hierarchy.tmeasure(ref, est, window=window, frame_size=frame_size) - # The reference should get a perfect score here - scores = mir_eval.hierarchy.tmeasure(ref, ref, - window=window, - frame_size=frame_size) + for k in scores: + assert k == 0.0 - for k in scores: - assert k == 1.0 + # The reference should get a perfect score here + scores = mir_eval.hierarchy.tmeasure(ref, ref, window=window, frame_size=frame_size) - for window in [5, 10, 15, 30, 90, None]: - for frame_size in [0.1, 0.5, 1.0]: - yield __test, window, frame_size + for k in scores: + assert k == 1.0 def test_tmeasure_warning(): - # Warn if there are missing boundaries from one layer to the next - ref = [[[0, 5], - [5, 10]], - [[0, 10]]] + ref = [[[0, 5], [5, 10]], [[0, 10]]] ref = [np.asarray(_) for _ in ref] - warnings.resetwarnings() - warnings.simplefilter('always') - with warnings.catch_warnings(record=True) as out: + with pytest.warns( + UserWarning, match="Segment hierarchy is inconsistent at level 1" + ): mir_eval.hierarchy.tmeasure(ref, ref) - assert len(out) > 0 - assert out[0].category is UserWarning - assert ('Segment hierarchy is inconsistent at level 1' - in str(out[0].message)) - def test_tmeasure_fail_span(): - # Does not start at 0 - ref = [[[1, 10]], - [[1, 5], - [5, 10]]] + ref = [[[1, 10]], [[1, 5], [5, 10]]] ref = [np.asarray(_) for _ in ref] - yield raises(ValueError)(mir_eval.hierarchy.tmeasure), ref, ref + with pytest.raises(ValueError): + mir_eval.hierarchy.tmeasure(ref, ref) # Does not end at the right time - ref = [[[0, 5]], - [[0, 5], - [5, 6]]] + ref = [[[0, 5]], [[0, 5], [5, 6]]] ref = [np.asarray(_) for _ in ref] - yield raises(ValueError)(mir_eval.hierarchy.tmeasure), ref, ref + with pytest.raises(ValueError): + mir_eval.hierarchy.tmeasure(ref, ref) # Two annotaions of different shape - ref = [[[0, 10]], - [[0, 5], - [5, 10]]] + ref = [[[0, 10]], [[0, 5], [5, 10]]] ref = [np.asarray(_) for _ in ref] - est = [[[0, 15]], - [[0, 5], - [5, 15]]] + est = [[[0, 15]], [[0, 5], [5, 15]]] est = [np.asarray(_) for _ in est] - yield raises(ValueError)(mir_eval.hierarchy.tmeasure), ref, est + with pytest.raises(ValueError): + mir_eval.hierarchy.tmeasure(ref, est) -def test_tmeasure_fail_frame_size(): - ref = [[[0, 60]], - [[0, 30], - [30, 60]]] +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "window, frame_size", + [(None, -1), (None, 0), (15, -1), (15, 0), (15, 30), (30, -1), (30, 0), (30, 60)], +) +def test_tmeasure_fail_frame_size(window, frame_size): + ref = [[[0, 60]], [[0, 30], [30, 60]]] ref = [np.asarray(_) for _ in ref] - @raises(ValueError) - def __test(window, frame_size): - mir_eval.hierarchy.tmeasure(ref, ref, - window=window, - frame_size=frame_size) + mir_eval.hierarchy.tmeasure(ref, ref, window=window, frame_size=frame_size) - for window in [None, 15, 30]: - for frame_size in [-1, 0]: - yield __test, window, frame_size - if window is not None: - yield __test, window, 2 * window - - -def test_lmeasure_pass(): +@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0]) +def test_lmeasure_pass(frame_size): # The estimate here gets none of the structure correct. ref = [[[0, 30]], [[0, 15], [15, 30]]] - ref_lab = [['A'], ['a', 'b']] + ref_lab = [["A"], ["a", "b"]] # convert to arrays ref = [np.asarray(_) for _ in ref] @@ -132,105 +102,97 @@ def test_lmeasure_pass(): est = ref[:1] est_lab = ref_lab[:1] - def __test(frame_size): - # The estimate should get 0 score here - scores = mir_eval.hierarchy.lmeasure(ref, ref_lab, est, est_lab, - frame_size=frame_size) - - for k in scores: - assert k == 0.0 + # The estimate should get 0 score here + scores = mir_eval.hierarchy.lmeasure( + ref, ref_lab, est, est_lab, frame_size=frame_size + ) - # The reference should get a perfect score here - scores = mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab, - frame_size=frame_size) + for k in scores: + assert k == 0.0 - for k in scores: - assert k == 1.0 + # The reference should get a perfect score here + scores = mir_eval.hierarchy.lmeasure( + ref, ref_lab, ref, ref_lab, frame_size=frame_size + ) - for frame_size in [0.1, 0.5, 1.0]: - yield __test, frame_size + for k in scores: + assert k == 1.0 def test_lmeasure_warning(): - # Warn if there are missing boundaries from one layer to the next - ref = [[[0, 5], - [5, 10]], - [[0, 10]]] + ref = [[[0, 5], [5, 10]], [[0, 10]]] ref = [np.asarray(_) for _ in ref] - ref_lab = [['a', 'b'], ['A']] + ref_lab = [["a", "b"], ["A"]] - warnings.resetwarnings() - warnings.simplefilter('always') - with warnings.catch_warnings(record=True) as out: + with pytest.warns( + UserWarning, match="Segment hierarchy is inconsistent at level 1" + ): mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab) - assert len(out) > 0 - assert out[0].category is UserWarning - assert ('Segment hierarchy is inconsistent at level 1' - in str(out[0].message)) - def test_lmeasure_fail_span(): - # Does not start at 0 - ref = [[[1, 10]], - [[1, 5], - [5, 10]]] + ref = [[[1, 10]], [[1, 5], [5, 10]]] - ref_lab = [['A'], ['a', 'b']] + ref_lab = [["A"], ["a", "b"]] ref = [np.asarray(_) for _ in ref] - yield (raises(ValueError)(mir_eval.hierarchy.lmeasure), - ref, ref_lab, ref, ref_lab) + with pytest.raises(ValueError): + mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab) # Does not end at the right time - ref = [[[0, 5]], - [[0, 5], - [5, 6]]] + ref = [[[0, 5]], [[0, 5], [5, 6]]] ref = [np.asarray(_) for _ in ref] - yield (raises(ValueError)(mir_eval.hierarchy.lmeasure), - ref, ref_lab, ref, ref_lab) + with pytest.raises(ValueError): + mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab) # Two annotations of different shape - ref = [[[0, 10]], - [[0, 5], - [5, 10]]] + ref = [[[0, 10]], [[0, 5], [5, 10]]] ref = [np.asarray(_) for _ in ref] - est = [[[0, 15]], - [[0, 5], - [5, 15]]] + est = [[[0, 15]], [[0, 5], [5, 15]]] est = [np.asarray(_) for _ in est] - yield (raises(ValueError)(mir_eval.hierarchy.lmeasure), - ref, ref_lab, est, ref_lab) + with pytest.raises(ValueError): + mir_eval.hierarchy.lmeasure(ref, ref_lab, est, ref_lab) -def test_lmeasure_fail_frame_size(): - ref = [[[0, 60]], - [[0, 30], - [30, 60]]] +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("frame_size", [-1, 0]) +def test_lmeasure_fail_frame_size(frame_size): + ref = [[[0, 60]], [[0, 30], [30, 60]]] ref = [np.asarray(_) for _ in ref] - ref_lab = [['A'], ['a', 'b']] + ref_lab = [["A"], ["a", "b"]] + + mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab, frame_size=frame_size) + + +SCORES_GLOB = "data/hierarchy/output*.json" +sco_files = sorted(glob.glob(SCORES_GLOB)) + - @raises(ValueError) - def __test(frame_size): - mir_eval.hierarchy.lmeasure(ref, ref_lab, ref, ref_lab, - frame_size=frame_size) +@pytest.fixture +def hierarchy_outcomes(request): + sco_f = request.param - for frame_size in [-1, 0]: - yield __test, frame_size + with open(sco_f, "r") as fdesc: + expected_scores = json.load(fdesc) + window = float(re.match(r".*output_w=(\d+).json$", sco_f).groups()[0]) + return expected_scores, window -def test_hierarchy_regression(): - ref_files = sorted(glob('data/hierarchy/ref*.lab')) - est_files = sorted(glob('data/hierarchy/est*.lab')) - out_files = sorted(glob('data/hierarchy/output*.json')) +@pytest.mark.parametrize("hierarchy_outcomes", sco_files, indirect=True) +def test_hierarchy_regression(hierarchy_outcomes): + expected_scores, window = hierarchy_outcomes + + # Hierarchy data is split across multiple lab files for these tests + ref_files = sorted(glob.glob("data/hierarchy/ref*.lab")) + est_files = sorted(glob.glob("data/hierarchy/est*.lab")) ref_hier = [mir_eval.io.load_labeled_intervals(_) for _ in ref_files] est_hier = [mir_eval.io.load_labeled_intervals(_) for _ in est_files] @@ -240,25 +202,16 @@ def test_hierarchy_regression(): est_ints = [seg[0] for seg in est_hier] est_labs = [seg[1] for seg in est_hier] - def __test(w, ref_i, ref_l, est_i, est_l, target): - outputs = mir_eval.hierarchy.evaluate(ref_i, ref_l, - est_i, est_l, - window=w) - - for key in target: - assert np.allclose(target[key], outputs[key], atol=A_TOL) - - for out in out_files: - with open(out, 'r') as fdesc: - target = json.load(fdesc) + outputs = mir_eval.hierarchy.evaluate( + ref_ints, ref_labs, est_ints, est_labs, window=window + ) - # Extract the window parameter - window = float(re.match('.*output_w=(\d+).json$', out).groups()[0]) - yield __test, window, ref_ints, ref_labs, est_ints, est_labs, target + assert outputs.keys() == expected_scores.keys() + for key in expected_scores: + assert np.allclose(expected_scores[key], outputs[key], atol=A_TOL) def test_count_inversions(): - # inversion count = |{(i, j) : a[i] >= b[j]}| a = [2, 4, 6] b = [1, 2, 3, 4] @@ -296,29 +249,30 @@ def test_count_inversions(): def test_meet(): - frame_size = 1 - int_hier = [np.array([[0, 10]]), - np.array([[0, 6], [6, 10]]), - np.array([[0, 2], [2, 4], [4, 6], [6, 8], [8, 10]])] + int_hier = [ + np.array([[0, 10]]), + np.array([[0, 6], [6, 10]]), + np.array([[0, 2], [2, 4], [4, 6], [6, 8], [8, 10]]), + ] - lab_hier = [['X'], - ['A', 'B'], - ['a', 'b', 'a', 'c', 'b']] + lab_hier = [["X"], ["A", "B"], ["a", "b", "a", "c", "b"]] # Target output - meet_truth = np.asarray([ - [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) - [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) - [2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb) - [2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb) - [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) - [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) - [1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc) - [1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc) - [1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb) - [1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb) - ]) + meet_truth = np.asarray( + [ + [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) + [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) + [2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb) + [2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb) + [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) + [3, 3, 2, 2, 3, 3, 1, 1, 1, 1], # (XAa) + [1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc) + [1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc) + [1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb) + [1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb) + ] + ) meet = mir_eval.hierarchy._meet(int_hier, lab_hier, frame_size) # Is it the right type? @@ -333,7 +287,6 @@ def test_meet(): def test_compare_frame_rankings(): - # number of pairs (i, j) # where ref[i] < ref[j] and est[i] >= est[j] @@ -345,34 +298,29 @@ def test_compare_frame_rankings(): # Just count the normalizers # No self-inversions are possible from ref to itself - inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, - transitive=True) + inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, transitive=True) assert inv == 0 assert norm == 5.0 - inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, - transitive=False) + inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, transitive=False) assert inv == 0 assert norm == 3.0 est = np.asarray([1, 2, 1, 3]) # In the transitive case, we lose two pairs # (1, 3) and (2, 2) -> (1, 1), (2, 1) - inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, est, - transitive=True) + inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, est, transitive=True) assert inv == 2 assert norm == 5.0 # In the non-transitive case, we only lose one pair # because (1,3) was not counted - inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, est, - transitive=False) + inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, est, transitive=False) assert inv == 1 assert norm == 3.0 # Do an all-zeros test ref = np.asarray([1, 1, 1, 1]) - inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, - transitive=True) + inv, norm = mir_eval.hierarchy._compare_frame_rankings(ref, ref, transitive=True) assert inv == 0 assert norm == 0.0 diff --git a/tests/test_input_output.py b/tests/test_input_output.py index bc3841ad..f3ba3d45 100644 --- a/tests/test_input_output.py +++ b/tests/test_input_output.py @@ -4,52 +4,52 @@ import json import mir_eval import warnings -import nose.tools import tempfile +import pytest def test_load_delimited(): # Test for ValueError when a non-string or file handle is passed - nose.tools.assert_raises( - IOError, mir_eval.io.load_delimited, None, [int]) + with pytest.raises(IOError): + mir_eval.io.load_delimited(None, [int]) + # Test for a value error when the wrong number of columns is passed - with tempfile.TemporaryFile('r+') as f: - f.write('10 20') + with tempfile.TemporaryFile("r+") as f: + f.write("10 20") f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_delimited, f, [int, int, int]) + with pytest.raises(ValueError): + mir_eval.io.load_delimited(f, [int, int, int]) # Test for a value error on conversion failure - with tempfile.TemporaryFile('r+') as f: - f.write('10 a 30') + with tempfile.TemporaryFile("r+") as f: + f.write("10 a 30") f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_delimited, f, [int, int, int]) + with pytest.raises(ValueError): + mir_eval.io.load_delimited(f, [int, int, int]) def test_load_delimited_commented(): - with tempfile.TemporaryFile('r+') as f: - f.write('; some comment\n10 20\n30 50') + with tempfile.TemporaryFile("r+") as f: + f.write("; some comment\n10 20\n30 50") f.seek(0) - col1, col2 = mir_eval.io.load_delimited(f, [int, int], comment=';') + col1, col2 = mir_eval.io.load_delimited(f, [int, int], comment=";") assert np.allclose(col1, [10, 30]) assert np.allclose(col2, [20, 50]) # Rewind and try with the default comment character f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_delimited, f, [int, int]) + with pytest.raises(ValueError): + mir_eval.io.load_delimited(f, [int, int]) # Rewind and try with no comment support f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_delimited, f, [int, int], - comment=None) + with pytest.raises(ValueError): + mir_eval.io.load_delimited(f, [int, int], comment=None) def test_load_delimited_nocomment(): - with tempfile.TemporaryFile('r+') as f: - f.write('10 20\n30 50') + with tempfile.TemporaryFile("r+") as f: + f.write("10 20\n30 50") f.seek(0) col1, col2 = mir_eval.io.load_delimited(f, [int, int]) assert np.allclose(col1, [10, 30]) @@ -57,7 +57,7 @@ def test_load_delimited_nocomment(): # Rewind and try with a different comment char f.seek(0) - col1, col2 = mir_eval.io.load_delimited(f, [int, int], comment=';') + col1, col2 = mir_eval.io.load_delimited(f, [int, int], comment=";") assert np.allclose(col1, [10, 30]) assert np.allclose(col2, [20, 50]) @@ -70,88 +70,69 @@ def test_load_delimited_nocomment(): def test_load_events(): # Test for a warning when invalid events are supplied - with tempfile.TemporaryFile('r+') as f: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with tempfile.TemporaryFile("r+") as f: + with pytest.warns(UserWarning, match="Events should be in increasing order."): # Non-increasing is invalid - f.write('10\n9') + f.write("10\n9") f.seek(0) events = mir_eval.io.load_events(f) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert (str(w[-1].message) == - 'Events should be in increasing order.') # Make sure events were read in correctly assert np.all(events == [10, 9]) def test_load_labeled_events(): # Test for a value error when invalid labeled events are supplied - with tempfile.TemporaryFile('r+') as f: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with tempfile.TemporaryFile("r+") as f: + with pytest.warns(UserWarning, match="Events should be in increasing order."): # Non-increasing is invalid - f.write('10 a\n9 b') + f.write("10 a\n9 b") f.seek(0) events, labels = mir_eval.io.load_labeled_events(f) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert (str(w[-1].message) == - 'Events should be in increasing order.') # Make sure events were read in correctly assert np.all(events == [10, 9]) # Make sure labels were read in correctly - assert labels == ['a', 'b'] + assert labels == ["a", "b"] def test_load_intervals(): # Test for a value error when invalid labeled events are supplied - with tempfile.TemporaryFile('r+') as f: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with tempfile.TemporaryFile("r+") as f: + with pytest.warns( + UserWarning, match="All interval durations must be strictly positive" + ): # Non-increasing is invalid - f.write('10 9\n9 10') + f.write("10 9\n9 10") f.seek(0) intervals = mir_eval.io.load_intervals(f) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert (str(w[-1].message) == - 'All interval durations must be strictly positive') # Make sure intervals were read in correctly assert np.all(intervals == [[10, 9], [9, 10]]) def test_load_labeled_intervals(): # Test for a value error when invalid labeled events are supplied - with tempfile.TemporaryFile('r+') as f: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with tempfile.TemporaryFile("r+") as f: + with pytest.warns( + UserWarning, match="All interval durations must be strictly positive" + ): # Non-increasing is invalid - f.write('10 9 a\n9 10 b') + f.write("10 9 a\n9 10 b") f.seek(0) intervals, labels = mir_eval.io.load_labeled_intervals(f) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert (str(w[-1].message) == - 'All interval durations must be strictly positive') # Make sure intervals were read in correctly assert np.all(intervals == [[10, 9], [9, 10]]) - assert labels == ['a', 'b'] + assert labels == ["a", "b"] def test_load_valued_intervals(): # Test for a value error when invalid valued events are supplied - with tempfile.TemporaryFile('r+') as f: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with tempfile.TemporaryFile("r+") as f: + with pytest.warns( + UserWarning, match="All interval durations must be strictly positive" + ): # Non-increasing is invalid - f.write('10 9 5\n9 10 6') + f.write("10 9 5\n9 10 6") f.seek(0) intervals, values = mir_eval.io.load_valued_intervals(f) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert (str(w[-1].message) == - 'All interval durations must be strictly positive') # Make sure intervals were read in correctly assert np.all(intervals == [[10, 9], [9, 10]]) assert np.all(values == [5, 6]) @@ -159,37 +140,33 @@ def test_load_valued_intervals(): def test_load_ragged_time_series(): # Test for ValueError when a non-string or file handle is passed - nose.tools.assert_raises( - IOError, mir_eval.io.load_ragged_time_series, None, float, - header=False) + with pytest.raises(IOError): + mir_eval.io.load_ragged_time_series(None, float, header=False) # Test for a value error on conversion failure - with tempfile.TemporaryFile('r+') as f: - f.write('10 a 30') + with tempfile.TemporaryFile("r+") as f: + f.write("10 a 30") f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_ragged_time_series, f, float, - header=False) + with pytest.raises(ValueError): + mir_eval.io.load_ragged_time_series(f, float, header=False) # Test for a value error on invalid time stamp - with tempfile.TemporaryFile('r+') as f: - f.write('a 10 30') + with tempfile.TemporaryFile("r+") as f: + f.write("a 10 30") f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_ragged_time_series, f, int, - header=False) + with pytest.raises(ValueError): + mir_eval.io.load_ragged_time_series(f, int, header=False) # Test for a value error on invalid time stamp with header - with tempfile.TemporaryFile('r+') as f: - f.write('x y z\na 10 30') + with tempfile.TemporaryFile("r+") as f: + f.write("x y z\na 10 30") f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_ragged_time_series, f, int, - header=True) + with pytest.raises(ValueError): + mir_eval.io.load_ragged_time_series(f, int, header=True) - with tempfile.TemporaryFile('r+') as f: - f.write('#comment\n0 1 2\n3 4\n# comment\n5 6 7') + with tempfile.TemporaryFile("r+") as f: + f.write("#comment\n0 1 2\n3 4\n# comment\n5 6 7") f.seek(0) - times, values = mir_eval.io.load_ragged_time_series(f, int, - header=False, - comment='#') + times, values = mir_eval.io.load_ragged_time_series( + f, int, header=False, comment="#" + ) assert np.allclose(times, [0, 3, 5]) assert np.allclose(values[0], [1, 2]) assert np.allclose(values[1], [4]) @@ -197,41 +174,33 @@ def test_load_ragged_time_series(): # Rewind with a wrong comment string f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_ragged_time_series, f, int, - header=False, comment='%') + with pytest.raises(ValueError): + mir_eval.io.load_ragged_time_series(f, int, header=False, comment="%") # Rewind with no comment string f.seek(0) - nose.tools.assert_raises( - ValueError, mir_eval.io.load_ragged_time_series, f, int, - header=False, comment=None) + with pytest.raises(ValueError): + mir_eval.io.load_ragged_time_series(f, int, header=False, comment=None) def test_load_tempo(): # Test the tempo loader - tempi, weight = mir_eval.io.load_tempo('data/tempo/ref01.lab') + tempi, weight = mir_eval.io.load_tempo("data/tempo/ref01.lab") assert np.allclose(tempi, [60, 120]) assert weight == 0.5 -@nose.tools.raises(ValueError) +@pytest.mark.xfail(raises=ValueError) def test_load_tempo_multiline(): - tempi, weight = mir_eval.io.load_tempo('data/tempo/bad00.lab') + tempi, weight = mir_eval.io.load_tempo("data/tempo/bad00.lab") -@nose.tools.raises(ValueError) +@pytest.mark.xfail(raises=ValueError) def test_load_tempo_badweight(): - tempi, weight = mir_eval.io.load_tempo('data/tempo/bad01.lab') + tempi, weight = mir_eval.io.load_tempo("data/tempo/bad01.lab") def test_load_bad_tempi(): - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - tempi, weight = mir_eval.io.load_tempo('data/tempo/bad02.lab') - - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert ('non-negative numbers' in str(w[-1].message)) + with pytest.warns(UserWarning, match="non-negative numbers"): + tempi, weight = mir_eval.io.load_tempo("data/tempo/bad02.lab") diff --git a/tests/test_key.py b/tests/test_key.py index 51b8ed0c..8fbcd35c 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -1,9 +1,9 @@ -''' +""" Tests for mir_eval.key -''' +""" import mir_eval -import nose.tools +import pytest import glob import json import numpy as np @@ -11,58 +11,52 @@ A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/key/ref*.txt' -EST_GLOB = 'data/key/est*.txt' -SCORES_GLOB = 'data/key/output*.json' - - -def __unit_test_key_function(metric): - - good_keys = ['C major', 'c major', 'C# major', 'Bb minor', - 'db minor', 'X', 'x', 'C other'] - # All of these are invalid key strings - bad_keys = ['C maj', 'Cb major', 'C', 'K major', 'F## minor' - 'X other', 'x minor'] - - for good_key in good_keys: - for bad_key in bad_keys: - # Should raise an error whether we pass a bad key as ref or est - nose.tools.assert_raises(ValueError, metric, good_key, bad_key) - nose.tools.assert_raises(ValueError, metric, bad_key, good_key) - - for good_key in good_keys: - # When the same key is passed for est and ref, score should be 1 - assert mir_eval.key.weighted_score(good_key, good_key) == 1. - - -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_key_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit test all metrics (one for now) - for metric in [mir_eval.key.weighted_score]: - yield __unit_test_key_function, metric - - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in an example key annotation - reference_key = mir_eval.io.load_key(ref_f) - # Load in an example key detector output - estimated_key = mir_eval.io.load_key(est_f) - # Compute scores - scores = mir_eval.key.evaluate(reference_key, estimated_key) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) +REF_GLOB = "data/key/ref*.txt" +EST_GLOB = "data/key/est*.txt" +SCORES_GLOB = "data/key/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.mark.parametrize( + "good_key", + ["C major", "c major", "C# major", "Bb minor", "db minor", "X", "x", "C other"], +) +@pytest.mark.parametrize( + "bad_key", ["C maj", "Cb major", "C", "K major", "F## minor" "X other", "x minor"] +) +def test_key_function_fail(good_key, bad_key): + score = mir_eval.key.weighted_score(good_key, good_key) + assert score == 1.0 + + with pytest.raises(ValueError): + mir_eval.key.weighted_score(good_key, bad_key) + with pytest.raises(ValueError): + mir_eval.key.weighted_score(bad_key, good_key) + + +@pytest.fixture +def key_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + reference_key = mir_eval.io.load_key(ref_f) + estimated_key = mir_eval.io.load_key(est_f) + + return reference_key, estimated_key, expected_scores + + +@pytest.mark.parametrize("key_data", file_sets, indirect=True) +def test_key_functions(key_data): + reference_key, estimated_key, expected_scores = key_data + # Compute scores + scores = mir_eval.key.evaluate(reference_key, estimated_key) + # Compare them + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_melody.py b/tests/test_melody.py index bcd2bf5c..60e7f796 100644 --- a/tests/test_melody.py +++ b/tests/test_melody.py @@ -1,35 +1,54 @@ # CREATED: 4/15/14 9:42 AM by Justin Salamon -''' +""" Unit tests for mir_eval.melody -''' +""" import numpy as np import json -import nose.tools import mir_eval import glob -import warnings +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/melody/ref*.txt' -EST_GLOB = 'data/melody/est*.txt' -SCORES_GLOB = 'data/melody/output*.json' +REF_GLOB = "data/melody/ref*.txt" +EST_GLOB = "data/melody/est*.txt" +SCORES_GLOB = "data/melody/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def melody_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + # Load in reference melody + ref_time, ref_freq = mir_eval.io.load_time_series(ref_f) + # Load in estimated melody + est_time, est_freq = mir_eval.io.load_time_series(est_f) + return ref_time, ref_freq, est_time, est_freq, expected_scores def test_hz2cents(): # Unit test some simple values - hz = np.array([0., 10., 5., 320., 1420.31238974231]) + hz = np.array([0.0, 10.0, 5.0, 320.0, 1420.31238974231]) # Expected cent conversion - expected_cent = np.array([0., 0., -1200., 6000., 8580.0773605]) + expected_cent = np.array([0.0, 0.0, -1200.0, 6000.0, 8580.0773605]) assert np.allclose(mir_eval.melody.hz2cents(hz), expected_cent) def test_freq_to_voicing(): # Unit test some simple values - hz = np.array([0., 100., -132.]) - expected_hz = np.array([0., 100., 132.]) + hz = np.array([0.0, 100.0, -132.0]) + expected_hz = np.array([0.0, 100.0, 132.0]) expected_voicing = np.array([0, 1, 0]) # Check voicing conversion res_hz, res_voicing = mir_eval.melody.freq_to_voicing(hz) @@ -37,9 +56,9 @@ def test_freq_to_voicing(): assert np.all(res_voicing == expected_voicing) # Unit test some simple values where voicing is given - hz = np.array([0., 100., -132., 0, 131.]) + hz = np.array([0.0, 100.0, -132.0, 0, 131.0]) voicing = np.array([0.8, 0.0, 1.0, 0.0, 0.5]) - expected_hz = np.array([0., 100., 132., 0., 131.]) + expected_hz = np.array([0.0, 100.0, 132.0, 0.0, 131.0]) expected_voicing = np.array([0.0, 0.0, 1.0, 0.0, 0.5]) # Check voicing conversion res_hz, res_voicing = mir_eval.melody.freq_to_voicing(hz, voicing=voicing) @@ -48,39 +67,37 @@ def test_freq_to_voicing(): def test_constant_hop_timebase(): - hop = .1 - end_time = .35 - expected_times = np.array([0, .1, .2, .3]) + hop = 0.1 + end_time = 0.35 + expected_times = np.array([0, 0.1, 0.2, 0.3]) res_times = mir_eval.melody.constant_hop_timebase(hop, end_time) assert np.allclose(res_times, expected_times) def test_resample_melody_series(): # Check for a small example including a zero transition - times = np.arange(4)/35.0 - cents = np.array([2., 0., -1., 1.]) + times = np.arange(4) / 35.0 + cents = np.array([2.0, 0.0, -1.0, 1.0]) voicing = np.array([1, 0, 1, 1]) - times_new = np.linspace(0, .08, 9) - expected_cents = np.array([2., 2., 2., 0., 0., 0., -.8, -.1, .6]) + times_new = np.linspace(0, 0.08, 9) + expected_cents = np.array([2.0, 2.0, 2.0, 0.0, 0.0, 0.0, -0.8, -0.1, 0.6]) expected_voicing = np.array([1, 1, 1, 0, 0, 0, 1, 1, 1]) - (res_cents, - res_voicing) = mir_eval.melody.resample_melody_series(times, cents, - voicing, times_new) + (res_cents, res_voicing) = mir_eval.melody.resample_melody_series( + times, cents, voicing, times_new + ) assert np.allclose(res_cents, expected_cents) assert np.allclose(res_voicing, expected_voicing) # Check for a small example including a zero transition - nonbinary voicing - times = np.arange(4)/35.0 - cents = np.array([2., 0., -1., 1.]) + times = np.arange(4) / 35.0 + cents = np.array([2.0, 0.0, -1.0, 1.0]) voicing = np.array([0.8, 0.0, 0.2, 1.0]) - times_new = np.linspace(0, .08, 9) - expected_cents = np.array([2., 2., 2., 0., 0., 0., -.8, -.1, .6]) - expected_voicing = np.array( - [0.8, 0.52, 0.24, 0.01, 0.08, 0.15, 0.28, 0.56, 0.84] + times_new = np.linspace(0, 0.08, 9) + expected_cents = np.array([2.0, 2.0, 2.0, 0.0, 0.0, 0.0, -0.8, -0.1, 0.6]) + expected_voicing = np.array([0.8, 0.52, 0.24, 0.01, 0.08, 0.15, 0.28, 0.56, 0.84]) + (res_cents, res_voicing) = mir_eval.melody.resample_melody_series( + times, cents, voicing, times_new ) - (res_cents, - res_voicing) = mir_eval.melody.resample_melody_series(times, cents, - voicing, times_new) assert np.allclose(res_cents, expected_cents) assert np.allclose(res_voicing, expected_voicing) @@ -89,26 +106,26 @@ def test_resample_melody_series_same_times(): # Check the case where the time bases are identical times = np.array([0.0, 0.1, 0.2, 0.3]) times_new = np.array([0.0, 0.1, 0.2, 0.3]) - cents = np.array([2., 0., -1., 1.]) + cents = np.array([2.0, 0.0, -1.0, 1.0]) voicing = np.array([0, 0, 1, 1]) - expected_cents = np.array([2., 0., -1., 1.]) + expected_cents = np.array([2.0, 0.0, -1.0, 1.0]) expected_voicing = np.array([False, False, True, True]) - (res_cents, - res_voicing) = mir_eval.melody.resample_melody_series(times, cents, - voicing, times_new) + (res_cents, res_voicing) = mir_eval.melody.resample_melody_series( + times, cents, voicing, times_new + ) assert np.allclose(res_cents, expected_cents) assert np.allclose(res_voicing, expected_voicing) # Check the case where the time bases are identical - nonbinary voicing times = np.array([0.0, 0.1, 0.2, 0.3]) times_new = np.array([0.0, 0.1, 0.2, 0.3]) - cents = np.array([2., 0., -1., 1.]) + cents = np.array([2.0, 0.0, -1.0, 1.0]) voicing = np.array([0.5, 0.8, 0.9, 1.0]) - expected_cents = np.array([2., 0., -1., 1.]) + expected_cents = np.array([2.0, 0.0, -1.0, 1.0]) expected_voicing = np.array([0.5, 0.8, 0.9, 1.0]) - (res_cents, - res_voicing) = mir_eval.melody.resample_melody_series(times, cents, - voicing, times_new) + (res_cents, res_voicing) = mir_eval.melody.resample_melody_series( + times, cents, voicing, times_new + ) assert np.allclose(res_cents, expected_cents) assert np.allclose(res_voicing, expected_voicing) @@ -119,17 +136,15 @@ def test_to_cent_voicing(): ref_time, ref_freq = mir_eval.io.load_time_series(ref_file) est_file = sorted(glob.glob(EST_GLOB))[0] est_time, est_freq = mir_eval.io.load_time_series(est_file) - ref_v, ref_c, est_v, est_c = mir_eval.melody.to_cent_voicing(ref_time, - ref_freq, - est_time, - est_freq) + ref_v, ref_c, est_v, est_c = mir_eval.melody.to_cent_voicing( + ref_time, ref_freq, est_time, est_freq + ) # Expected values test_range = np.arange(220, 225) expected_ref_v = np.array([False, False, False, True, True]) - expected_ref_c = np.array([0., 0., 0., 6056.8837818916609, - 6028.5504583021921]) - expected_est_v = np.array([False]*5) - expected_est_c = np.array([5351.3179423647571]*5) + expected_ref_c = np.array([0.0, 0.0, 0.0, 6056.8837818916609, 6028.5504583021921]) + expected_est_v = np.array([False] * 5) + expected_est_c = np.array([5351.3179423647571] * 5) assert np.allclose(ref_v[test_range], expected_ref_v) assert np.allclose(ref_c[test_range], expected_ref_c) assert np.allclose(est_v[test_range], expected_est_v) @@ -137,37 +152,40 @@ def test_to_cent_voicing(): # Test that a 0 is added to the beginning for return_item in mir_eval.melody.to_cent_voicing( - np.array([1., 2.]), np.array([440., 442.]), np.array([1., 2.]), - np.array([441., 443.])): + np.array([1.0, 2.0]), + np.array([440.0, 442.0]), + np.array([1.0, 2.0]), + np.array([441.0, 443.0]), + ): assert len(return_item) == 3 assert return_item[0] == return_item[1] # Test custom voicings ref_time, ref_freq = mir_eval.io.load_time_series(ref_file) _, ref_reward = mir_eval.io.load_time_series("data/melody/reward00.txt") - _, est_voicing = mir_eval.io.load_time_series( - "data/melody/voicingest00.txt" + _, est_voicing = mir_eval.io.load_time_series("data/melody/voicingest00.txt") + (ref_v, ref_c, est_v, est_c) = mir_eval.melody.to_cent_voicing( + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing=est_voicing, + ref_reward=ref_reward, ) - (ref_v, ref_c, - est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time, - ref_freq, - est_time, - est_freq, - est_voicing=est_voicing, - ref_reward=ref_reward) # Expected values test_range = np.arange(220, 225) - expected_ref_v = np.array([0., 0., 0., 1., 0.3]) - expected_ref_c = np.array([0., 0., 0., 6056.8837818916609, - 6028.5504583021921]) + expected_ref_v = np.array([0.0, 0.0, 0.0, 1.0, 0.3]) + expected_ref_c = np.array([0.0, 0.0, 0.0, 6056.8837818916609, 6028.5504583021921]) expected_est_v = np.array([0.2, 0.2, 0.2, 0.2, 0.2]) - expected_est_c = np.array([5351.3179423647571]*5) + expected_est_c = np.array([5351.3179423647571] * 5) assert np.allclose(ref_v[test_range], expected_ref_v) assert np.allclose(ref_c[test_range], expected_ref_c) assert np.allclose(est_v[test_range], expected_est_v) assert np.allclose(est_c[test_range], expected_est_c) +# We can ignore this warning, which occurs when testing with all-zeros reward +@pytest.mark.filterwarnings("ignore:Reference melody has no voiced frames") def test_continuous_voicing_metrics(): ref_time = np.array([0.0, 0.1, 0.2, 0.3]) ref_freq = np.array([440.0, 0.0, 220.0, 220.0]) @@ -187,50 +205,50 @@ def test_continuous_voicing_metrics(): all_expected = [ # perfect { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.0, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.5, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.0, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.5, }, # all wrong { - 'Voicing Recall': 0.0, - 'Voicing False Alarm': 1.0, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.0, + "Voicing Recall": 0.0, + "Voicing False Alarm": 1.0, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.0, }, # all 0.5 { - 'Voicing Recall': 0.5, - 'Voicing False Alarm': 0.5, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.25, + "Voicing Recall": 0.5, + "Voicing False Alarm": 0.5, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.25, }, # almost right { - 'Voicing Recall': 0.8, - 'Voicing False Alarm': 0.2, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.4, + "Voicing Recall": 0.8, + "Voicing False Alarm": 0.2, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.4, }, # almost wrong { - 'Voicing Recall': 0.2, - 'Voicing False Alarm': 0.8, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.1, + "Voicing Recall": 0.2, + "Voicing False Alarm": 0.8, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.1, }, ] for est_voicing, expected_scores in zip(all_est_voicing, all_expected): - actual_scores = mir_eval.melody.evaluate(ref_time, ref_freq, est_time, - est_freq, - est_voicing=est_voicing) + actual_scores = mir_eval.melody.evaluate( + ref_time, ref_freq, est_time, est_freq, est_voicing=est_voicing + ) for metric in actual_scores: assert np.isclose(actual_scores[metric], expected_scores[metric]) @@ -249,110 +267,129 @@ def test_continuous_voicing_metrics(): all_expected = [ # uniform { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.0, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.5, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.0, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.5, }, # uniform - different number { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.0, - 'Raw Pitch Accuracy': 1. / 3., - 'Raw Chroma Accuracy': 2. / 3., - 'Overall Accuracy': 0.5, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.0, + "Raw Pitch Accuracy": 1.0 / 3.0, + "Raw Chroma Accuracy": 2.0 / 3.0, + "Overall Accuracy": 0.5, }, # all zero { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.75, - 'Raw Pitch Accuracy': 0.0, - 'Raw Chroma Accuracy': 0.0, - 'Overall Accuracy': 0.25, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.75, + "Raw Pitch Accuracy": 0.0, + "Raw Chroma Accuracy": 0.0, + "Overall Accuracy": 0.25, }, # one weight { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 2. / 3., - 'Raw Pitch Accuracy': 1.0, - 'Raw Chroma Accuracy': 1.0, - 'Overall Accuracy': 0.5, + "Voicing Recall": 1.0, + "Voicing False Alarm": 2.0 / 3.0, + "Raw Pitch Accuracy": 1.0, + "Raw Chroma Accuracy": 1.0, + "Overall Accuracy": 0.5, }, # two weights { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.5, - 'Raw Pitch Accuracy': 0.5, - 'Raw Chroma Accuracy': 1.0, - 'Overall Accuracy': 0.5, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.5, + "Raw Pitch Accuracy": 0.5, + "Raw Chroma Accuracy": 1.0, + "Overall Accuracy": 0.5, }, # slightly generous { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.0, - 'Raw Pitch Accuracy': 0.5, - 'Raw Chroma Accuracy': 0.75, - 'Overall Accuracy': 0.625, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.0, + "Raw Pitch Accuracy": 0.5, + "Raw Chroma Accuracy": 0.75, + "Overall Accuracy": 0.625, }, # big penalty { - 'Voicing Recall': 1.0, - 'Voicing False Alarm': 0.0, - 'Raw Pitch Accuracy': 0.1, - 'Raw Chroma Accuracy': 0.2, - 'Overall Accuracy': 0.325, + "Voicing Recall": 1.0, + "Voicing False Alarm": 0.0, + "Raw Pitch Accuracy": 0.1, + "Raw Chroma Accuracy": 0.2, + "Overall Accuracy": 0.325, }, ] for ref_reward, expected_scores in zip(all_rewards, all_expected): - actual_scores = mir_eval.melody.evaluate(ref_time, ref_freq, est_time, - est_freq, - est_voicing=est_voicing, - ref_reward=ref_reward) + actual_scores = mir_eval.melody.evaluate( + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing=est_voicing, + ref_reward=ref_reward, + ) for metric in actual_scores: assert np.isclose(actual_scores[metric], expected_scores[metric]) -def __unit_test_voicing_measures(metric): +def test_voicing_measures_empty(): # We need a special test for voicing_measures because it only takes 2 args - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + with pytest.warns() as w: # First, test for warnings due to empty voicing arrays - score = metric(np.array([]), np.array([])) - assert len(w) == 4 - assert np.all([issubclass(wrn.category, UserWarning) for wrn in w]) - assert [str(wrn.message) - for wrn in w] == ["Reference voicing array is empty.", - "Estimated voicing array is empty.", - "Reference melody has no voiced frames.", - "Estimated melody has no voiced frames."] - # And that the metric is 0 - assert np.allclose(score, 0) + score = mir_eval.melody.voicing_measures(np.array([]), np.array([])) + assert len(w) == 4 + assert np.all([issubclass(wrn.category, UserWarning) for wrn in w]) + assert [str(wrn.message) for wrn in w] == [ + "Reference voicing array is empty.", + "Estimated voicing array is empty.", + "Reference melody has no voiced frames.", + "Estimated melody has no voiced frames.", + ] + # And that the metric is 0 + assert np.allclose(score, 0) + + +def test_voicing_measures_unvoiced(): + with pytest.warns() as w: # Also test for a warning when the arrays have non-voiced content - metric(np.ones(10), np.zeros(10)) - assert len(w) == 5 + mir_eval.melody.voicing_measures(np.ones(10), np.zeros(10)) + assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert str(w[-1].message) == "Estimated melody has no voiced frames." - # Now test validation function - voicing arrays must be the same size - nose.tools.assert_raises(ValueError, metric, np.ones(10), np.ones(12)) + +@pytest.mark.xfail(raises=ValueError) +def test_melody_voicing_badlength(): + # ref and est voicings must be the same length + mir_eval.melody.voicing_measures(np.ones(10), np.ones(11)) -def __unit_test_melody_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') +@pytest.mark.parametrize( + "metric", + [ + mir_eval.melody.raw_pitch_accuracy, + mir_eval.melody.raw_chroma_accuracy, + mir_eval.melody.overall_accuracy, + ], +) +def test_melody_function_empty(metric): + with pytest.warns() as w: # First, test for warnings due to empty voicing arrays score = metric(np.array([]), np.array([]), np.array([]), np.array([])) assert len(w) == 6 assert np.all([issubclass(wrn.category, UserWarning) for wrn in w]) - assert [str(wrn.message) - for wrn in w] == ["Reference voicing array is empty.", - "Estimated voicing array is empty.", - "Reference melody has no voiced frames.", - "Estimated melody has no voiced frames.", - "Reference frequency array is empty.", - "Estimated frequency array is empty."] + assert [str(wrn.message) for wrn in w] == [ + "Reference voicing array is empty.", + "Estimated voicing array is empty.", + "Reference melody has no voiced frames.", + "Estimated melody has no voiced frames.", + "Reference frequency array is empty.", + "Estimated frequency array is empty.", + ] # And that the metric is 0 assert np.allclose(score, 0) # Also test for a warning when the arrays have non-voiced content @@ -361,80 +398,43 @@ def __unit_test_melody_function(metric): assert issubclass(w[-1].category, UserWarning) assert str(w[-1].message) == "Estimated melody has no voiced frames." - # Now test validation function - all inputs must be same length - nose.tools.assert_raises(ValueError, metric, np.ones(10), - np.ones(12), np.ones(10), np.ones(10)) - - -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_melody_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.melody.voicing_measures, - mir_eval.melody.raw_pitch_accuracy, - mir_eval.melody.raw_chroma_accuracy, - mir_eval.melody.overall_accuracy]: - if metric == mir_eval.melody.voicing_measures: - yield (__unit_test_voicing_measures, metric) - else: - yield (__unit_test_melody_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in reference melody - ref_time, ref_freq = mir_eval.io.load_time_series(ref_f) - # Load in estimated melody - est_time, est_freq = mir_eval.io.load_time_series(est_f) - scores = mir_eval.melody.evaluate(ref_time, ref_freq, est_time, - est_freq) - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) - - -def test_melody_functions_continuous_voicing_equivalence(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.melody.voicing_measures, - mir_eval.melody.raw_pitch_accuracy, - mir_eval.melody.raw_chroma_accuracy, - mir_eval.melody.overall_accuracy]: - if metric == mir_eval.melody.voicing_measures: - yield (__unit_test_voicing_measures, metric) - else: - yield (__unit_test_melody_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in reference melody - ref_time, ref_freq = mir_eval.io.load_time_series(ref_f) - ref_reward = np.ones(ref_time.shape) # uniform reward - # Load in estimated melody - est_time, est_freq = mir_eval.io.load_time_series(est_f) - # voicing equivalent from frequency - est_voicing = (est_freq >= 0).astype('float') - scores = mir_eval.melody.evaluate(ref_time, ref_freq, est_time, - est_freq, est_voicing=est_voicing, - ref_reward=ref_reward) - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", + [ + mir_eval.melody.raw_pitch_accuracy, + mir_eval.melody.raw_chroma_accuracy, + mir_eval.melody.overall_accuracy, + ], +) +@pytest.mark.parametrize( + "ref_freq, est_freq", [(np.ones(11), np.ones(10)), (np.ones(10), np.ones(11))] +) +def test_melody_badlength(metric, ref_freq, est_freq): + # frequency and time must be the same length + metric(np.ones(10), ref_freq, np.ones(10), est_freq) + + +@pytest.mark.parametrize("melody_data", file_sets, indirect=True) +@pytest.mark.parametrize("voicing", [False, True]) +def test_melody_functions(melody_data, voicing): + ref_time, ref_freq, est_time, est_freq, expected_scores = melody_data + # When voicing=True, do the continuous voicing equivalence check + if voicing: + ref_reward = np.ones_like(ref_time) + est_voicing = (est_freq >= 0).astype(float) + else: + ref_reward = None + est_voicing = None + scores = mir_eval.melody.evaluate( + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing=est_voicing, + ref_reward=ref_reward, + ) + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_multipitch.py b/tests/test_multipitch.py index eccf52bc..514fc884 100644 --- a/tests/test_multipitch.py +++ b/tests/test_multipitch.py @@ -7,14 +7,34 @@ import mir_eval import glob import warnings -import nose.tools +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/multipitch/ref*.txt' -EST_GLOB = 'data/multipitch/est*.txt' -SCORES_GLOB = 'data/multipitch/output*.json' +REF_GLOB = "data/multipitch/ref*.txt" +EST_GLOB = "data/multipitch/est*.txt" +SCORES_GLOB = "data/multipitch/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def multipitch_data(request): + ref_f, est_f, sco_f = request.param + + with open(sco_f, "r") as f_handle: + expected_score = json.load(f_handle) + + ref_times, ref_freqs = mir_eval.io.load_ragged_time_series(ref_f) + est_times, est_freqs = mir_eval.io.load_ragged_time_series(est_f) + + return ref_times, ref_freqs, est_times, est_freqs, expected_score def __frequencies_equal(freqs_a, freqs_b): @@ -47,10 +67,10 @@ def test_resample_multif0(): times = np.array([0.00, 0.01, 0.02, 0.03]) empty_times = np.array([]) freqs = [ - np.array([200.]), + np.array([200.0]), np.array([]), - np.array([300., 400., 500.]), - np.array([300., 500.]) + np.array([300.0, 400.0, 500.0]), + np.array([300.0, 500.0]), ] empty_freqs = [] target_times1 = times @@ -59,23 +79,21 @@ def test_resample_multif0(): expected_freqs1 = freqs expected_freqs2 = [ - np.array([200.]), - np.array([200.]), + np.array([200.0]), + np.array([200.0]), + np.array([]), + np.array([300.0, 500.0]), np.array([]), - np.array([300., 500.]), - np.array([]) ] expected_freqs3 = empty_freqs - expected_freqs4 = [np.array([])]*4 - - actual_freqs1 = mir_eval.multipitch.resample_multipitch( - times, freqs, target_times1) - actual_freqs2 = mir_eval.multipitch.resample_multipitch( - times, freqs, target_times2) - actual_freqs3 = mir_eval.multipitch.resample_multipitch( - times, freqs, target_times3) + expected_freqs4 = [np.array([])] * 4 + + actual_freqs1 = mir_eval.multipitch.resample_multipitch(times, freqs, target_times1) + actual_freqs2 = mir_eval.multipitch.resample_multipitch(times, freqs, target_times2) + actual_freqs3 = mir_eval.multipitch.resample_multipitch(times, freqs, target_times3) actual_freqs4 = mir_eval.multipitch.resample_multipitch( - empty_times, empty_freqs, target_times1) + empty_times, empty_freqs, target_times1 + ) assert __frequencies_equal(actual_freqs1, expected_freqs1) assert __frequencies_equal(actual_freqs2, expected_freqs2) @@ -85,16 +103,16 @@ def test_resample_multif0(): def test_frequencies_to_midi(): frequencies = [ - np.array([440.]), + np.array([440.0]), np.array([]), - np.array([220., 660., 512.]), - np.array([300., 512.]) + np.array([220.0, 660.0, 512.0]), + np.array([300.0, 512.0]), ] expected = [ - np.array([69.]), + np.array([69.0]), np.array([]), - np.array([57., 76.01955000865388, 71.623683437704088]), - np.array([62.369507723654657, 71.623683437704088]) + np.array([57.0, 76.01955000865388, 71.623683437704088]), + np.array([62.369507723654657, 71.623683437704088]), ] actual = mir_eval.multipitch.frequencies_to_midi(frequencies) assert __frequencies_equal(actual, expected) @@ -102,16 +120,16 @@ def test_frequencies_to_midi(): def test_midi_to_chroma(): midi_frequencies = [ - np.array([69.]), + np.array([69.0]), np.array([]), - np.array([57., 76.01955000865388, 71.623683437704088]), - np.array([62.369507723654657, 71.623683437704088]) + np.array([57.0, 76.01955000865388, 71.623683437704088]), + np.array([62.369507723654657, 71.623683437704088]), ] expected = [ - np.array([9.]), + np.array([9.0]), np.array([]), - np.array([9., 4.01955000865388, 11.623683437704088]), - np.array([2.3695077236546567, 11.623683437704088]) + np.array([9.0, 4.01955000865388, 11.623683437704088]), + np.array([2.3695077236546567, 11.623683437704088]), ] actual = mir_eval.multipitch.midi_to_chroma(midi_frequencies) assert __frequencies_equal(actual, expected) @@ -119,10 +137,10 @@ def test_midi_to_chroma(): def test_compute_num_freqs(): frequencies = [ - np.array([256.]), + np.array([256.0]), np.array([]), - np.array([362.03867196751236, 128., 512.]), - np.array([300., 512.]) + np.array([362.03867196751236, 128.0, 512.0]), + np.array([300.0, 512.0]), ] expected = np.array([1, 0, 3, 2]) actual = mir_eval.multipitch.compute_num_freqs(frequencies) @@ -131,41 +149,41 @@ def test_compute_num_freqs(): def test_compute_num_true_positives(): ref_freqs = [ - np.array([96., 100.]), + np.array([96.0, 100.0]), np.array([]), - np.array([81.]), - np.array([102., 84., 108.]), - np.array([98.745824285950576, 108.]) + np.array([81.0]), + np.array([102.0, 84.0, 108.0]), + np.array([98.745824285950576, 108.0]), ] est_freqs = [ - np.array([96.]), + np.array([96.0]), np.array([]), - np.array([200., 82.]), - np.array([102., 84., 108.]), - np.array([99., 108.]) + np.array([200.0, 82.0]), + np.array([102.0, 84.0, 108.0]), + np.array([99.0, 108.0]), ] expected = np.array([1, 0, 0, 3, 2]) - actual = mir_eval.multipitch.compute_num_true_positives( - ref_freqs, est_freqs) + actual = mir_eval.multipitch.compute_num_true_positives(ref_freqs, est_freqs) assert np.allclose(actual, expected, atol=A_TOL) ref_freqs_chroma = [ - np.array([0., 1.5]), + np.array([0.0, 1.5]), np.array([]), - np.array([2.]), - np.array([5.1, 6., 11.]), - np.array([11.9, 11.9]) + np.array([2.0]), + np.array([5.1, 6.0, 11.0]), + np.array([11.9, 11.9]), ] est_freqs_chroma = [ - np.array([0.]), + np.array([0.0]), np.array([]), - np.array([5., 2.6]), - np.array([5.1, 6., 11.]), - np.array([0.2, 11.5]) + np.array([5.0, 2.6]), + np.array([5.1, 6.0, 11.0]), + np.array([0.2, 11.5]), ] expected = np.array([1, 0, 0, 3, 2]) actual = mir_eval.multipitch.compute_num_true_positives( - ref_freqs_chroma, est_freqs_chroma, chroma=True) + ref_freqs_chroma, est_freqs_chroma, chroma=True + ) assert np.allclose(actual, expected, atol=A_TOL) @@ -178,10 +196,11 @@ def test_accuracy_metrics(): expected_recall = 0.75 expected_accuracy = 0.6 - (actual_precision, - actual_recall, - actual_accuarcy) = mir_eval.multipitch.compute_accuracy( - true_positives, n_ref, n_est) + ( + actual_precision, + actual_recall, + actual_accuarcy, + ) = mir_eval.multipitch.compute_accuracy(true_positives, n_ref, n_est) assert np.allclose(actual_precision, expected_precision, atol=A_TOL) assert np.allclose(actual_recall, expected_recall, atol=A_TOL) @@ -198,11 +217,12 @@ def test_error_score_metrics(): expected_efa = 0.125 expected_etot = 0.375 - (actual_esub, - actual_emiss, - actual_efa, - actual_etot) = mir_eval.multipitch.compute_err_score( - true_positives, n_ref, n_est) + ( + actual_esub, + actual_emiss, + actual_efa, + actual_etot, + ) = mir_eval.multipitch.compute_err_score(true_positives, n_ref, n_est) assert np.allclose(actual_esub, expected_esub, atol=A_TOL) assert np.allclose(actual_emiss, expected_emiss, atol=A_TOL) @@ -213,32 +233,30 @@ def test_error_score_metrics(): def unit_test_metrics(): empty_array = np.array([]) ref_time = np.array([0.0, 0.1]) - ref_freqs = [np.array([201.]), np.array([])] + ref_freqs = [np.array([201.0]), np.array([])] est_time = np.array([0.0, 0.1]) - est_freqs = [np.array([200.]), np.array([])] + est_freqs = [np.array([200.0]), np.array([])] # ref sizes unequal - nose.tools.assert_raises( - ValueError, mir_eval.multipitch.metrics, - np.array([0.0]), ref_freqs, est_time, est_freqs) + with pytest.raises(ValueError): + mir_eval.multipitch.metrics(np.array([0.0]), ref_freqs, est_time, est_freqs) # est sizes unequal - nose.tools.assert_raises( - ValueError, mir_eval.multipitch.metrics, - ref_time, ref_freqs, np.array([0.0]), est_freqs) + with pytest.raises(ValueError): + mir_eval.multipitch.metrics(ref_time, ref_freqs, np.array([0.0]), est_freqs) with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") # Test for warnings on empty values actual_score = mir_eval.multipitch.metrics( - ref_time, [empty_array, empty_array], - est_time, [empty_array, empty_array]) + ref_time, [empty_array, empty_array], est_time, [empty_array, empty_array] + ) assert len(w) == 6 assert issubclass(w[-1].category, UserWarning) assert str(w[-1].message) == "Reference frequencies are all empty." with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") # Test for warnings on empty values # test all inputs empty mir_eval.multipitch.metrics(empty_array, [], empty_array, []) @@ -247,7 +265,7 @@ def unit_test_metrics(): assert str(w[-1].message) == "Reference frequencies are all empty." with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") # reference empty mir_eval.multipitch.metrics(empty_array, [], est_time, est_freqs) assert len(w) == 9 @@ -255,43 +273,61 @@ def unit_test_metrics(): assert str(w[-1].message) == "Reference frequencies are all empty." with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") # estimate empty mir_eval.multipitch.metrics(ref_time, ref_freqs, empty_array, []) assert len(w) == 5 assert issubclass(w[-1].category, UserWarning) assert str(w[-1].message) == "Estimate frequencies are all empty." - expected_score = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + expected_score = ( + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ) assert np.allclose(actual_score, expected_score) # test perfect estimate ref_time = np.array([0.0, 0.1, 0.2]) - ref_freqs = [np.array([201.]), np.array([]), np.array([300.5, 87.1])] - actual_score = mir_eval.multipitch.metrics( - ref_time, ref_freqs, ref_time, ref_freqs) - - expected_score = (1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, - 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0) + ref_freqs = [np.array([201.0]), np.array([]), np.array([300.5, 87.1])] + actual_score = mir_eval.multipitch.metrics(ref_time, ref_freqs, ref_time, ref_freqs) + + expected_score = ( + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + ) assert np.allclose(actual_score, expected_score) -def regression_test_evaluate(): - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - score_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(score_files) > 0 - - for ref_f, est_f, score_f in zip(ref_files, est_files, score_files): - with open(score_f, 'r') as f_handle: - expected_score = json.load(f_handle) - - ref_times, ref_freqs = mir_eval.io.load_ragged_time_series(ref_f) - est_times, est_freqs = mir_eval.io.load_ragged_time_series(est_f) +@pytest.mark.parametrize("multipitch_data", file_sets, indirect=True) +def test_evaluate_regression(multipitch_data): + ref_times, ref_freqs, est_times, est_freqs, expected_score = multipitch_data - actual_score = mir_eval.multipitch.evaluate( - ref_times, ref_freqs, est_times, est_freqs) + actual_score = mir_eval.multipitch.evaluate( + ref_times, ref_freqs, est_times, est_freqs + ) - assert __scores_equal(actual_score, expected_score) + assert __scores_equal(actual_score, expected_score) diff --git a/tests/test_onset.py b/tests/test_onset.py index a4d9188c..cdeaf560 100644 --- a/tests/test_onset.py +++ b/tests/test_onset.py @@ -1,79 +1,79 @@ -''' +""" Unit tests for mir_eval.onset -''' +""" import numpy as np +import pytest import json import mir_eval import glob import warnings -import nose.tools A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/onset/ref*.txt' -EST_GLOB = 'data/onset/est*.txt' -SCORES_GLOB = 'data/onset/output*.json' - - -def __unit_test_onset_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # First, test for a warning on empty onsets - metric(np.array([]), np.arange(10)) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Reference onsets are empty." - metric(np.arange(10), np.array([])) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Estimated onsets are empty." - # And that the metric is 0 - assert np.allclose(metric(np.array([]), np.array([])), 0) - - # Now test validation function - onsets must be 1d ndarray - onsets = np.array([[1., 2.]]) - nose.tools.assert_raises(ValueError, metric, onsets, onsets) - # onsets must be in seconds (so not huge) - onsets = np.array([1e10, 1e11]) - nose.tools.assert_raises(ValueError, metric, onsets, onsets) - # onsets must be sorted - onsets = np.array([2., 1.]) - nose.tools.assert_raises(ValueError, metric, onsets, onsets) +REF_GLOB = "data/onset/ref*.txt" +EST_GLOB = "data/onset/est*.txt" +SCORES_GLOB = "data/onset/output*.json" +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def onset_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + reference_onsets = mir_eval.io.load_events(ref_f) + estimated_onsets = mir_eval.io.load_events(est_f) + + return reference_onsets, estimated_onsets, expected_scores + + +def test_onset_empty_warnings(): + with pytest.warns(UserWarning, match="Reference onsets are empty."): + mir_eval.onset.f_measure(np.array([]), np.arange(10)) + + with pytest.warns(UserWarning, match="Estimated onsets are empty."): + mir_eval.onset.f_measure(np.arange(10), np.array([])) + + with pytest.warns(UserWarning, match="onsets are empty"): + # Also verify that the score is 0 + assert np.allclose(mir_eval.onset.f_measure(np.array([]), np.array([])), 0) + + +@pytest.mark.xfail(raisses=ValueError) +@pytest.mark.parametrize( + "onsets", + [ + np.array([[1.0, 2.0]]), # must be 1d ndarray + np.array([1e10, 1e11]), # must not be huge + np.array([2.0, 1.0]), # must be sorted + ], +) +def test_onset_fail(onsets): + mir_eval.onset.f_measure(onsets, onsets) + + +def test_onset_match(): # Valid onsets which are the same produce a score of 1 for all metrics onsets = np.arange(10, dtype=np.float64) - assert np.allclose(metric(onsets, onsets), 1) - - -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_onset_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.onset.f_measure]: - yield (__unit_test_onset_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in an example onset annotation - reference_onsets = mir_eval.io.load_events(ref_f) - # Load in an example onset tracker output - estimated_onsets = mir_eval.io.load_events(est_f) - # Compute scores - scores = mir_eval.onset.evaluate(reference_onsets, estimated_onsets) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) + assert np.allclose(mir_eval.onset.f_measure(onsets, onsets), 1.0) + + +@pytest.mark.parametrize("onset_data", file_sets, indirect=True) +def test_onset_functions(onset_data): + reference_onsets, estimated_onsets, expected_scores = onset_data + + # Compute scores + scores = mir_eval.onset.evaluate(reference_onsets, estimated_onsets) + # Compare them + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 2d348e46..e1b4d4b3 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -6,76 +6,105 @@ import json import mir_eval import glob -import warnings -import nose.tools +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/pattern/ref*.txt' -EST_GLOB = 'data/pattern/est*.txt' -SCORES_GLOB = 'data/pattern/output*.json' +REF_GLOB = "data/pattern/ref*.txt" +EST_GLOB = "data/pattern/est*.txt" +SCORES_GLOB = "data/pattern/output*.json" +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) -def __unit_test_pattern_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # First, test for a warning on empty pattern +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def pattern_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + reference_patterns = mir_eval.io.load_patterns(ref_f) + estimated_patterns = mir_eval.io.load_patterns(est_f) + + return reference_patterns, estimated_patterns, expected_scores + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.pattern.standard_FPR, + mir_eval.pattern.establishment_FPR, + mir_eval.pattern.occurrence_FPR, + mir_eval.pattern.three_layer_FPR, + mir_eval.pattern.first_n_three_layer_P, + mir_eval.pattern.first_n_target_proportion_R, + ], +) +def test_pattern_empty(metric): + # First, test for a warning on empty pattern + with pytest.warns(UserWarning, match="Reference patterns are empty"): metric([[[]]], [[[(100, 20)]]]) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == 'Reference patterns are empty.' + + with pytest.warns(UserWarning, match="Estimated patterns are empty"): metric([[[(100, 20)]]], [[[]]]) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Estimated patterns are empty." + + with pytest.warns(UserWarning, match="patterns are empty"): # And that the metric is 0 assert np.allclose(metric([[[]]], [[[]]]), 0) - # Now test validation function - patterns must contain at least 1 occ - patterns = [[[(100, 20)]], []] - nose.tools.assert_raises(ValueError, metric, patterns, patterns) - # The (onset, midi) tuple must contain 2 elements - patterns = [[[(100, 20, 3)]]] - nose.tools.assert_raises(ValueError, metric, patterns, patterns) +@pytest.mark.parametrize( + "metric", + [ + mir_eval.pattern.standard_FPR, + mir_eval.pattern.establishment_FPR, + mir_eval.pattern.occurrence_FPR, + mir_eval.pattern.three_layer_FPR, + mir_eval.pattern.first_n_three_layer_P, + mir_eval.pattern.first_n_target_proportion_R, + ], +) +@pytest.mark.parametrize( + "patterns", + [ + [[[(100, 20)]], []], # patterns must have at least one occurrence + [[[(100, 20, 3)]]], # (onset, midi) tuple must contain 2 elements + ], +) +@pytest.mark.xfail(raises=ValueError) +def test_pattern_failure(metric, patterns): + metric(patterns, patterns) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.pattern.standard_FPR, + mir_eval.pattern.establishment_FPR, + mir_eval.pattern.occurrence_FPR, + mir_eval.pattern.three_layer_FPR, + mir_eval.pattern.first_n_three_layer_P, + mir_eval.pattern.first_n_target_proportion_R, + ], +) +def test_pattern_perfect(metric): # Valid patterns which are the same produce a score of 1 for all metrics patterns = [[[(100, 20), (200, 30)]]] assert np.allclose(metric(patterns, patterns), 1) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_pattern_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.pattern.standard_FPR, - mir_eval.pattern.establishment_FPR, - mir_eval.pattern.occurrence_FPR, - mir_eval.pattern.three_layer_FPR, - mir_eval.pattern.first_n_three_layer_P, - mir_eval.pattern.first_n_target_proportion_R]: - yield (__unit_test_pattern_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in reference and estimated patterns - reference_patterns = mir_eval.io.load_patterns(ref_f) - estimated_patterns = mir_eval.io.load_patterns(est_f) - # Compute scores - scores = mir_eval.pattern.evaluate(reference_patterns, - estimated_patterns) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) +@pytest.mark.parametrize("pattern_data", file_sets, indirect=True) +def test_pattern_functions(pattern_data): + reference_patterns, estimated_patterns, expected_scores = pattern_data + # Compute scores + scores = mir_eval.pattern.evaluate(reference_patterns, estimated_patterns) + # Compare them + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_segment.py b/tests/test_segment.py index b46dd1c4..8c19618f 100644 --- a/tests/test_segment.py +++ b/tests/test_segment.py @@ -1,175 +1,216 @@ -''' +""" Unit tests for mir_eval.segment -''' +""" import numpy as np import json import mir_eval import glob -import warnings -import nose.tools +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/segment/ref*.lab' -EST_GLOB = 'data/segment/est*.lab' -SCORES_GLOB = 'data/segment/output*.json' +REF_GLOB = "data/segment/ref*.lab" +EST_GLOB = "data/segment/est*.lab" +SCORES_GLOB = "data/segment/output*.json" +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) -def __unit_test_boundary_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # Test for warning when empty intervals with no trimming +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def segment_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + # Load in an example segmentation annotation + ref_intervals, ref_labels = mir_eval.io.load_labeled_intervals(ref_f) + # Load in an example segmentation tracker output + est_intervals, est_labels = mir_eval.io.load_labeled_intervals(est_f) + + return ref_intervals, ref_labels, est_intervals, est_labels, expected_scores + + +@pytest.mark.parametrize( + "metric", [mir_eval.segment.detection, mir_eval.segment.deviation] +) +def test_segment_boundary_empty(metric): + with pytest.warns(UserWarning, match="Reference intervals are empty"): metric(np.zeros((0, 2)), np.array([[1, 2], [2, 3]]), trim=False) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Reference intervals are empty." - # Now test when 1 interval with trimming + + with pytest.warns(UserWarning, match="Estimated intervals are empty"): metric(np.array([[1, 2], [2, 3]]), np.array([[1, 2]]), trim=True) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == "Estimated intervals are empty." - # Check for correct behavior in empty intervals + + with pytest.warns(UserWarning, match="intervals are empty"): empty_intervals = np.zeros((0, 2)) if metric == mir_eval.segment.detection: assert np.allclose(metric(empty_intervals, empty_intervals), 0) else: assert np.all(np.isnan(metric(empty_intervals, empty_intervals))) - # Now test validation function - intervals must be n by 2 - intervals = np.array([1, 2, 3, 4]) - nose.tools.assert_raises(ValueError, metric, intervals, intervals) - # Interval boundaries must be positive - intervals = np.array([[-1, 2], [2, 3]]) - nose.tools.assert_raises(ValueError, metric, intervals, intervals) - # Positive interval durations - intervals = np.array([[2, 1], [2, 3]]) - nose.tools.assert_raises(ValueError, metric, intervals, intervals) - # Check for correct behavior when intervals are the same + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", [mir_eval.segment.detection, mir_eval.segment.deviation] +) +@pytest.mark.parametrize( + "intervals", + [ + # Now test validation function - intervals must be n by 2 + np.array([1, 2, 3, 4]), + # Interval boundaries must be positive + np.array([[-1, 2], [2, 3]]), + # Positive interval durations + np.array([[2, 1], [2, 3]]), + ], +) +def test_segment_boundary_errors(metric, intervals): + metric(intervals, intervals) + + +def test_segment_boundary_detection_perfect(): correct_intervals = np.array([[0, 1], [1, 2]]) - if metric == mir_eval.segment.detection: - assert np.allclose(metric(correct_intervals, correct_intervals), 1) - else: - assert np.allclose(metric(correct_intervals, correct_intervals), 0) - - -def __unit_test_structure_function(metric): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # Test for warning when empty intervals - score = metric(np.zeros((0, 2)), [], np.zeros((0, 2)), []) - assert len(w) == 2 - assert issubclass(w[0].category, UserWarning) - assert issubclass(w[1].category, UserWarning) - assert str(w[0].message) == "Reference intervals are empty." - assert str(w[1].message) == "Estimated intervals are empty." - # And that the metric is 0 - assert np.allclose(score, 0) - - # Test for non-matching numbers of intervals and labels - intervals = np.array([[2, 1], [2, 3]]) - labels = ['a', 'b', 'c'] - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # Now test validation function - intervals must be n by 2 - intervals = np.arange(4) - labels = ['a', 'b', 'c', 'd'] - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # Interval boundaries must be positive - intervals = np.array([[-1, 2], [2, 3]]) - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # Positive interval durations - intervals = np.array([[2, 1], [2, 3]]) - labels = ['a', 'b'] - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # Number of intervals must match number of labels - labels = ['a'] - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # Intervals must start at 0 - intervals = np.array([[1, 2], [2, 3]]) - labels = ['a', 'b'] - nose.tools.assert_raises(ValueError, metric, intervals, labels, intervals, - labels) - # End times must match + assert np.allclose( + mir_eval.segment.detection(correct_intervals, correct_intervals), 1 + ) + + +def test_segment_boundary_deviation_perfect(): + correct_intervals = np.array([[0, 1], [1, 2]]) + assert np.allclose( + mir_eval.segment.deviation(correct_intervals, correct_intervals), 0 + ) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.segment.pairwise, + mir_eval.segment.rand_index, + mir_eval.segment.ari, + mir_eval.segment.mutual_information, + mir_eval.segment.nce, + mir_eval.segment.vmeasure, + ], +) +def test_segment_structure_empty(metric): + with pytest.warns(UserWarning, match="Reference intervals are empty"): + metric(np.zeros((0, 2)), [], np.array([[0, 1]]), ["foo"]) + + with pytest.warns(UserWarning, match="Estimated intervals are empty"): + metric(np.array([[0, 1]]), ["foo"], np.zeros((0, 2)), []) + + with pytest.warns(UserWarning, match="intervals are empty"): + empty_intervals = np.zeros((0, 2)) + assert np.allclose(metric(empty_intervals, [], empty_intervals, []), 0) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", + [ + mir_eval.segment.pairwise, + mir_eval.segment.rand_index, + mir_eval.segment.ari, + mir_eval.segment.mutual_information, + mir_eval.segment.nce, + mir_eval.segment.vmeasure, + ], +) +@pytest.mark.parametrize( + "intervals, labels", + [ + # Test for non-matching numbers of intervals and labels + (np.array([[2, 1], [2, 3]]), ["a", "b", "c"]), + # Now test validation function - intervals must be n by 2 + (np.arange(4), ["a", "b", "c", "d"]), + # Interval boundaries must be positive + (np.array([[-1, 2], [2, 3]]), ["a", "b"]), + # Positive interval durations + (np.array([[2, 1], [2, 3]]), ["a", "b"]), + # Number of intervals must match number of labels + (np.array([[2, 1], [2, 3]]), ["a"]), + # Intervals must start at 0 + (np.array([[1, 2], [2, 3]]), ["a", "b"]), + ], +) +def test_segment_structure_fail(metric, intervals, labels): + metric(intervals, labels, intervals, labels) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", + [ + mir_eval.segment.pairwise, + mir_eval.segment.rand_index, + mir_eval.segment.ari, + mir_eval.segment.mutual_information, + mir_eval.segment.nce, + mir_eval.segment.vmeasure, + ], +) +def test_segment_structure_end_mismatch(metric): reference_intervals = np.array([[0, 1], [1, 2]]) estimated_intervals = np.array([[0, 1], [1, 3]]) - nose.tools.assert_raises(ValueError, metric, reference_intervals, labels, - estimated_intervals, labels) - # Check for correct output when input is the same - estimated_intervals = reference_intervals + labels = ["a", "b"] + metric(reference_intervals, labels, estimated_intervals, labels) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.segment.pairwise, + mir_eval.segment.rand_index, + mir_eval.segment.ari, + mir_eval.segment.mutual_information, + mir_eval.segment.nce, + mir_eval.segment.vmeasure, + ], +) +def test_segment_structure_perfect(metric): + reference_intervals = np.array([[0, 1], [1, 2]]) + estimated_intervals = np.array([[0, 1], [1, 2]]) + labels = ["a", "b"] if metric == mir_eval.segment.mutual_information: - assert np.allclose(metric(reference_intervals, labels, - estimated_intervals, labels), - [np.log(2), 1, 1]) + assert np.allclose( + metric(reference_intervals, labels, estimated_intervals, labels), + [np.log(2), 1, 1], + ) else: - assert np.allclose(metric(reference_intervals, labels, - estimated_intervals, labels), 1) - + assert np.allclose( + metric(reference_intervals, labels, estimated_intervals, labels), 1 + ) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) +@pytest.mark.parametrize("segment_data", file_sets, indirect=True) +def test_segment_functions(segment_data): + ref_intervals, ref_labels, est_intervals, est_labels, expected_scores = segment_data -def __unit_test_permuted_segments(sco_f, ref_int, ref_lab, - est_int, est_lab, scores): - # Test for issue #202 - - # Generate a random permutation of the reference segments - idx = np.random.permutation(np.arange(len(ref_int))) + # Compute scores + scores = mir_eval.segment.evaluate( + ref_intervals, ref_labels, est_intervals, est_labels + ) + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) - perm_int = ref_int[idx] - perm_lab = [ref_lab[_] for _ in idx] - perm_scores = mir_eval.segment.evaluate(perm_int, perm_lab, - est_int, est_lab) +@pytest.mark.parametrize("segment_data", file_sets, indirect=True) +def test_segment_functions_permuted(segment_data): + ref_intervals, ref_labels, est_intervals, est_labels, expected_scores = segment_data + # Also check with permuted references + idx = np.random.permutation(np.arange(len(ref_intervals))) + perm_int = ref_intervals[idx] + perm_lab = [ref_labels[_] for _ in idx] + scores = mir_eval.segment.evaluate(perm_int, perm_lab, est_intervals, est_labels) + assert scores.keys() == expected_scores.keys() for metric in scores: - __check_score(sco_f, metric, perm_scores[metric], scores[metric]) - - -def test_segment_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests for boundary - for metric in [mir_eval.segment.detection, - mir_eval.segment.deviation]: - yield (__unit_test_boundary_function, metric) - # And structure - for metric in [mir_eval.segment.pairwise, - mir_eval.segment.rand_index, - mir_eval.segment.ari, - mir_eval.segment.mutual_information, - mir_eval.segment.nce, - mir_eval.segment.vmeasure]: - yield (__unit_test_structure_function, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in an example segmentation annotation - ref_intervals, ref_labels = mir_eval.io.load_labeled_intervals(ref_f) - # Load in an example segmentation tracker output - est_intervals, est_labels = mir_eval.io.load_labeled_intervals(est_f) - - # Compute scores - scores = mir_eval.segment.evaluate(ref_intervals, ref_labels, - est_intervals, est_labels) - # Compare them - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) - - yield (__unit_test_permuted_segments, sco_f, - ref_intervals, ref_labels, - est_intervals, est_labels, scores) + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_separation.py b/tests/test_separation.py index a9ce4e61..4ad7f553 100644 --- a/tests/test_separation.py +++ b/tests/test_separation.py @@ -1,32 +1,72 @@ -''' +""" unit tests for mir_eval.separation load randomly generated source and estimated source signals and the output from BSS_eval MATLAB implementation, make sure the results from mir_eval numerically match. -''' +""" import numpy as np import mir_eval import glob -import nose.tools +import pytest import json import os -import warnings A_TOL = 1e-2 -REF_GLOB = 'data/separation/ref*' -EST_GLOB = 'data/separation/est*' -SCORES_GLOB = 'data/separation/output*.json' +REF_GLOB = "data/separation/ref*" +EST_GLOB = "data/separation/est*" +SCORES_GLOB = "data/separation/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def separation_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_results = json.load(f) + expected_sources = expected_results["Sources"] + expected_frames = expected_results["Framewise"] + expected_images = expected_results["Images"] + expected_image_frames = expected_results["Images Framewise"] + + # Load in example source separation data + ref_sources = __load_and_stack_wavs(ref_f) + est_sources = __load_and_stack_wavs(est_f) + + # Test inference for single source passed as single dimensional array + if ref_sources.shape[0] == 1 and est_sources.shape[0] == 1: + ref_sources = ref_sources[0] + est_sources = est_sources[0] + + return ( + ref_sources, + est_sources, + expected_sources, + expected_frames, + expected_images, + expected_image_frames, + ) + + +@pytest.fixture(autouse=True) +def seed_rng(): + # Seed the RNG before each test run + np.random.seed(1999) def __load_and_stack_wavs(directory): - ''' Load all wavs in a directory and stack them vertically into a matrix - ''' + """Load all wavs in a directory and stack them vertically into a matrix""" stacked_audio_data = [] global_fs = None - for f in sorted(glob.glob(os.path.join(directory, '*.wav'))): + for f in sorted(glob.glob(os.path.join(directory, "*.wav"))): audio_data, fs = mir_eval.io.load_wav(f) assert global_fs is None or fs == global_fs global_fs = fs @@ -35,14 +75,14 @@ def __load_and_stack_wavs(directory): def __generate_multichannel(mono_sig, nchan=2, gain=1.0, reverse=False): - ''' Turn a single channel (ie. mono) audio sample into a multichannel + """Turn a single channel (ie. mono) audio sample into a multichannel (e.g. stereo) Note: to achieve channels of silence pass gain=0 - ''' + """ # add the channels dimension input_3d = np.atleast_3d(mono_sig) # get the desired number of channels - stackin = [input_3d]*nchan + stackin = [input_3d] * nchan # apply the gain to the new channels stackin[1:] = np.multiply(gain, stackin[1:]) if reverse: @@ -51,61 +91,88 @@ def __generate_multichannel(mono_sig, nchan=2, gain=1.0, reverse=False): return np.dstack(stackin) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def __unit_test_empty_input(metric): - if (metric == mir_eval.separation.bss_eval_sources or - metric == mir_eval.separation.bss_eval_images): +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources, + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_empty_input(metric): + if ( + metric == mir_eval.separation.bss_eval_sources + or metric == mir_eval.separation.bss_eval_images + ): args = [np.array([]), np.array([])] - elif (metric == mir_eval.separation.bss_eval_sources_framewise or - metric == mir_eval.separation.bss_eval_images_framewise): + elif ( + metric == mir_eval.separation.bss_eval_sources_framewise + or metric == mir_eval.separation.bss_eval_images_framewise + ): args = [np.array([]), np.array([]), 40, 20] - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + + with pytest.warns(UserWarning, match="is empty") as record: + # First, test for a warning on empty audio data metric(*args) - assert len(w) == 2 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == ("estimated_sources is empty, " - "should be of size (nsrc, nsample). " - "sdr, sir, sar, and perm will all be " - "empty np.ndarrays") # And that the metric returns empty arrays assert np.allclose(metric(*args), np.array([])) + assert "reference_sources is empty" in str(record[0].message) + assert "estimated_sources is empty" in str(record[1].message) -def __unit_test_silent_input(metric): + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources, + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_silent_input(metric): # Test for error when there is a silent reference/estimated source - if (metric == mir_eval.separation.bss_eval_images or - metric == mir_eval.separation.bss_eval_images_framewise): - ref_sources = np.vstack((np.zeros((1, 100, 2)), - np.random.random_sample((2, 100, 2)))) - est_sources = np.vstack((np.zeros((1, 100, 2)), - np.random.random_sample((2, 100, 2)))) - else: - ref_sources = np.vstack((np.zeros(100), - np.random.random_sample((2, 100)))) - est_sources = np.vstack((np.zeros(100), - np.random.random_sample((2, 100)))) - if (metric == mir_eval.separation.bss_eval_sources or - metric == mir_eval.separation.bss_eval_images): - nose.tools.assert_raises(ValueError, metric, ref_sources[:2], - est_sources[1:]) - nose.tools.assert_raises(ValueError, metric, ref_sources[1:], - est_sources[:2]) - elif (metric == mir_eval.separation.bss_eval_sources_framewise or - metric == mir_eval.separation.bss_eval_images_framewise): - nose.tools.assert_raises(ValueError, metric, ref_sources[:2], - est_sources[1:], 40, 20) - nose.tools.assert_raises(ValueError, metric, ref_sources[1:], - est_sources[:2], 40, 20) + if ( + metric == mir_eval.separation.bss_eval_images + or metric == mir_eval.separation.bss_eval_images_framewise + ): + ref_sources = np.vstack( + (np.zeros((1, 100, 2)), np.random.random_sample((2, 100, 2))) + ) + est_sources = np.vstack( + (np.zeros((1, 100, 2)), np.random.random_sample((2, 100, 2))) + ) else: - raise ValueError('Unknown metric {}'.format(metric)) - - -def __unit_test_partial_silence(metric): + ref_sources = np.vstack((np.zeros(100), np.random.random_sample((2, 100)))) + est_sources = np.vstack((np.zeros(100), np.random.random_sample((2, 100)))) + if ( + metric == mir_eval.separation.bss_eval_sources + or metric == mir_eval.separation.bss_eval_images + ): + with pytest.raises(ValueError): + metric(ref_sources[:2], est_sources[1:]) + with pytest.raises(ValueError): + metric(ref_sources[1:], est_sources[:2]) + elif ( + metric == mir_eval.separation.bss_eval_sources_framewise + or metric == mir_eval.separation.bss_eval_images_framewise + ): + with pytest.raises(ValueError): + metric(ref_sources[:2], est_sources[1:], 40, 20) + with pytest.raises(ValueError): + metric(ref_sources[1:], est_sources[:2], 40, 20) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_partial_silence(metric): # Test for a full window of silence in reference/estimated source if metric == mir_eval.separation.bss_eval_sources_framewise: silence = np.zeros((2, 20)) @@ -113,15 +180,13 @@ def __unit_test_partial_silence(metric): elif metric == mir_eval.separation.bss_eval_images_framewise: silence = np.zeros((2, 20, 2)) sound = np.random.random_sample((2, 20, 2)) - else: - raise ValueError('Unknown metric {}'.format(metric)) # test with silence in the reference - results = metric(np.concatenate((sound, silence, sound), - axis=1), - np.concatenate((sound, sound, sound), - axis=1), - window=10, - hop=10) + results = metric( + np.concatenate((sound, silence, sound), axis=1), + np.concatenate((sound, sound, sound), axis=1), + window=10, + hop=10, + ) for measure in results: for idx, source in enumerate(measure): if idx < 2 or idx > 3: @@ -129,14 +194,14 @@ def __unit_test_partial_silence(metric): elif idx < 4: assert np.isnan(source[idx]) else: - raise ValueError('Testing error in partial silence test') + raise ValueError("Testing error in partial silence test") # test with silence in the estimate - results = metric(np.concatenate((sound, sound, sound), - axis=1), - np.concatenate((sound, silence, sound), - axis=1), - window=10, - hop=10) + results = metric( + np.concatenate((sound, sound, sound), axis=1), + np.concatenate((sound, silence, sound), axis=1), + window=10, + hop=10, + ) for measure in results: for idx, source in enumerate(measure): if idx < 2 or idx > 3: @@ -144,56 +209,96 @@ def __unit_test_partial_silence(metric): elif idx < 4: assert np.isnan(source[idx]) else: - raise ValueError('Testing error in partial silence test') - - -def __unit_test_incompatible_shapes(metric): + raise ValueError("Testing error in partial silence test") + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources, + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_incompatible_shapes(metric): # Test for error when shape is different - if (metric == mir_eval.separation.bss_eval_images or - metric == mir_eval.separation.bss_eval_images_framewise): + if ( + metric == mir_eval.separation.bss_eval_images + or metric == mir_eval.separation.bss_eval_images_framewise + ): sources_4 = np.random.random_sample((4, 100, 2)) sources_3 = np.random.random_sample((3, 100, 2)) sources_4_chan = np.random.random_sample((4, 100, 3)) else: sources_4 = np.random.random_sample((4, 100)) sources_3 = np.random.random_sample((3, 100)) - if (metric == mir_eval.separation.bss_eval_sources or - metric == mir_eval.separation.bss_eval_images): + if ( + metric == mir_eval.separation.bss_eval_sources + or metric == mir_eval.separation.bss_eval_images + ): args1 = [sources_3, sources_4] args2 = [sources_4, sources_3] - elif (metric == mir_eval.separation.bss_eval_sources_framewise or - metric == mir_eval.separation.bss_eval_images_framewise): + elif ( + metric == mir_eval.separation.bss_eval_sources_framewise + or metric == mir_eval.separation.bss_eval_images_framewise + ): args1 = [sources_3, sources_4, 40, 20] args2 = [sources_4, sources_3, 40, 20] - else: - raise ValueError('Unknown metric {}'.format(metric)) - nose.tools.assert_raises(ValueError, metric, *args1) - nose.tools.assert_raises(ValueError, metric, *args2) - if (metric == mir_eval.separation.bss_eval_images or - metric == mir_eval.separation.bss_eval_images_framewise): - nose.tools.assert_raises(ValueError, metric, sources_4, sources_4_chan) - - -def __unit_test_too_many_sources(metric): + with pytest.raises(ValueError): + metric(*args1) + with pytest.raises(ValueError): + metric(*args2) + if ( + metric == mir_eval.separation.bss_eval_images + or metric == mir_eval.separation.bss_eval_images_framewise + ): + with pytest.raises(ValueError): + metric(sources_4, sources_4_chan) + + +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources, + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_too_many_sources(metric): # Test for error when too many sources or references are provided - many_sources = np.random.random_sample((mir_eval.separation.MAX_SOURCES*2, - 400)) + many_sources = np.random.random_sample((mir_eval.separation.MAX_SOURCES * 2, 400)) if metric == mir_eval.separation.bss_eval_sources: - nose.tools.assert_raises(ValueError, metric, many_sources, - many_sources) + with pytest.raises(ValueError): + metric(many_sources, many_sources) elif metric == mir_eval.separation.bss_eval_sources_framewise: - nose.tools.assert_raises(ValueError, metric, many_sources, - many_sources, 40, 20) - - -def __unit_test_too_many_dimensions(metric): + with pytest.raises(ValueError): + metric(many_sources, many_sources, 40, 20) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources, + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_too_many_dimensions(metric): # Test for detection of too high dimensioned images ref_sources = np.random.random_sample((4, 100, 2, 3)) est_sources = np.random.random_sample((4, 100, 2, 3)) - nose.tools.assert_raises(ValueError, metric, ref_sources, est_sources) + metric(ref_sources, est_sources) -def __unit_test_default_permutation(metric): +@pytest.mark.parametrize( + "metric", + [mir_eval.separation.bss_eval_sources, mir_eval.separation.bss_eval_images], +) +def test_default_permutation(metric): # Test for default permutation matrix when not computing permutation if metric == mir_eval.separation.bss_eval_sources: ref_sources = np.random.random_sample((4, 100)) @@ -201,13 +306,18 @@ def __unit_test_default_permutation(metric): elif metric == mir_eval.separation.bss_eval_images: ref_sources = np.random.random_sample((4, 100, 2)) est_sources = np.random.random_sample((4, 100, 2)) - else: - raise ValueError('Unknown metric {}'.format(metric)) results = metric(ref_sources, est_sources, compute_permutation=False) assert np.array_equal(results[-1], np.asarray([0, 1, 2, 3])) -def __unit_test_framewise_small_window(metric): +@pytest.mark.parametrize( + "metric", + [ + mir_eval.separation.bss_eval_sources_framewise, + mir_eval.separation.bss_eval_images_framewise, + ], +) +def test_framewise_small_window(metric): # Test for invalid win/hop parameter detection if metric == mir_eval.separation.bss_eval_sources_framewise: ref_sources = np.random.random_sample((4, 100)) @@ -217,122 +327,116 @@ def __unit_test_framewise_small_window(metric): ref_sources = np.random.random_sample((4, 100, 2)) est_sources = np.random.random_sample((4, 100, 2)) comparison_fcn = mir_eval.separation.bss_eval_images - else: - raise ValueError('Unknown metric {}'.format(metric)) + # Test with window larger than source length - assert np.allclose(np.squeeze(metric(ref_sources, - est_sources, - window=120, - hop=20)), - comparison_fcn(ref_sources, est_sources, False), - atol=A_TOL) + assert np.allclose( + np.squeeze(metric(ref_sources, est_sources, window=120, hop=20)), + comparison_fcn(ref_sources, est_sources, False), + atol=A_TOL, + ) # Test with hop larger than source length - assert np.allclose(np.squeeze(metric(ref_sources, - est_sources, - window=20, - hop=120)), - comparison_fcn(ref_sources, est_sources, False), - atol=A_TOL) - - -def test_separation_functions(): - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - assert len(ref_files) == len(est_files) == len(sco_files) > 0 - - # Unit tests - for metric in [mir_eval.separation.bss_eval_sources, - mir_eval.separation.bss_eval_sources_framewise, - mir_eval.separation.bss_eval_images, - mir_eval.separation.bss_eval_images_framewise]: - yield (__unit_test_empty_input, metric) - yield (__unit_test_silent_input, metric) - yield (__unit_test_incompatible_shapes, metric) - yield (__unit_test_too_many_sources, metric) - yield (__unit_test_too_many_dimensions, metric) - for metric in [mir_eval.separation.bss_eval_sources, - mir_eval.separation.bss_eval_images]: - yield (__unit_test_default_permutation, metric) - for metric in [mir_eval.separation.bss_eval_sources_framewise, - mir_eval.separation.bss_eval_images_framewise]: - yield (__unit_test_framewise_small_window, metric) - yield (__unit_test_partial_silence, metric) - # Regression tests - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_results = json.load(f) - expected_sources = expected_results['Sources'] - expected_frames = expected_results['Framewise'] - expected_images = expected_results['Images'] - expected_image_frames = expected_results['Images Framewise'] - # Load in example source separation data - ref_sources = __load_and_stack_wavs(ref_f) - est_sources = __load_and_stack_wavs(est_f) - # Test inference for single source passed as single dimensional array - if ref_sources.shape[0] == 1 and est_sources.shape[0] == 1: - ref_sources = ref_sources[0] - est_sources = est_sources[0] - - # Compute scores - scores = mir_eval.separation.evaluate( - ref_sources, est_sources, - window=expected_frames['win'], hop=expected_frames['hop'] - ) - # Compare them - for metric in scores: - if 'Sources - ' in metric: - test_data_name = metric.replace('Sources - ', '') - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_sources[test_data_name]) - elif 'Sources Frames - ' in metric: - test_data_name = metric.replace('Sources Frames - ', '') - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, scores[metric], - expected_frames[test_data_name]) - - # Compute scores with images - ref_images = __generate_multichannel(ref_sources, - expected_images['nchan']) - est_images = __generate_multichannel(est_sources, - expected_images['nchan'], - expected_images['gain'], - expected_images['reverse']) - image_scores = mir_eval.separation.evaluate( - ref_images, est_images - ) - # Compare them - for metric in image_scores: - if 'Images - ' in metric: - test_data_name = metric.replace('Images - ', '') - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, image_scores[metric], - expected_images[test_data_name]) - - # Compute scores with images framewise - ref_images = __generate_multichannel(ref_sources, - expected_image_frames['nchan']) - est_images = __generate_multichannel(est_sources, - expected_image_frames['nchan'], - expected_image_frames['gain'], - expected_image_frames['reverse']) - imageframe_scores = mir_eval.separation.evaluate( - ref_images, est_images, - window=expected_image_frames['win'], - hop=expected_image_frames['hop'] - ) - # Compare them - for metric in imageframe_scores: - if 'Images Frames - ' in metric: - test_data_name = metric.replace('Images Frames - ', '') - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, sco_f, metric, imageframe_scores[metric], - expected_image_frames[test_data_name]) + assert np.allclose( + np.squeeze(metric(ref_sources, est_sources, window=20, hop=120)), + comparison_fcn(ref_sources, est_sources, False), + atol=A_TOL, + ) + + +@pytest.mark.parametrize("separation_data", file_sets, indirect=True) +def test_separation_functions(separation_data): + ( + ref_sources, + est_sources, + expected_sources, + expected_frames, + expected_images, + expected_image_frames, + ) = separation_data + + # Compute scores + scores = mir_eval.separation.evaluate( + ref_sources, + est_sources, + window=expected_frames["win"], + hop=expected_frames["hop"], + ) + + # Compare them + for key in scores: + if "Sources - " in key: + test_data_name = key.replace("Sources - ", "") + assert np.allclose( + scores[key], expected_sources[test_data_name], atol=A_TOL + ) + elif "Sources Frames - " in key: + test_data_name = key.replace("Sources Frames - ", "") + assert np.allclose(scores[key], expected_frames[test_data_name], atol=A_TOL) + + +@pytest.mark.parametrize("separation_data", file_sets, indirect=True) +def test_separation_images(separation_data): + ( + ref_sources, + est_sources, + expected_sources, + expected_frames, + expected_images, + expected_image_frames, + ) = separation_data + # Compute scores with images + ref_images = __generate_multichannel(ref_sources, expected_images["nchan"]) + est_images = __generate_multichannel( + est_sources, + expected_images["nchan"], + expected_images["gain"], + expected_images["reverse"], + ) + image_scores = mir_eval.separation.evaluate(ref_images, est_images) + # Compare them + for key in image_scores: + if "Images - " in key: + test_data_name = key.replace("Images - ", "") + assert np.allclose( + image_scores[key], expected_images[test_data_name], atol=A_TOL + ) + + +@pytest.mark.parametrize("separation_data", file_sets, indirect=True) +def test_separation_images_framewise(separation_data): + ( + ref_sources, + est_sources, + expected_sources, + expected_frames, + expected_images, + expected_image_frames, + ) = separation_data + + # Compute scores with images framewise + ref_images = __generate_multichannel(ref_sources, expected_image_frames["nchan"]) + est_images = __generate_multichannel( + est_sources, + expected_image_frames["nchan"], + expected_image_frames["gain"], + expected_image_frames["reverse"], + ) + imageframe_scores = mir_eval.separation.evaluate( + ref_images, + est_images, + window=expected_image_frames["win"], + hop=expected_image_frames["hop"], + ) + # Compare them + for key in imageframe_scores: + if "Images Frames - " in key: + test_data_name = key.replace("Images Frames - ", "") + assert np.allclose( + imageframe_scores[key], + expected_image_frames[test_data_name], + atol=A_TOL, + ) # Catch a few exceptions in the evaluate function image_scores = mir_eval.separation.evaluate(ref_images, est_images) # make sure sources is not being evaluated on images - assert 'Sources - Source to Distortion' not in image_scores + assert "Sources - Source to Distortion" not in image_scores diff --git a/tests/test_sonify.py b/tests/test_sonify.py index ecf8596e..f4afeb74 100644 --- a/tests/test_sonify.py +++ b/tests/test_sonify.py @@ -1,77 +1,90 @@ """ Unit tests for sonification methods """ +import pytest import mir_eval import numpy as np import scipy -def test_clicks(): +@pytest.mark.parametrize("times", [np.array([1.0]), np.arange(10.0)]) +@pytest.mark.parametrize("fs", [8000, 44100]) +def test_clicks(times, fs): # Test output length for a variety of parameter settings - for times in [np.array([1.]), np.arange(10)*1.]: - for fs in [8000, 44100]: - click_signal = mir_eval.sonify.clicks(times, fs) - assert len(click_signal) == times.max()*fs + int(fs*.1) + 1 - click_signal = mir_eval.sonify.clicks(times, fs, length=1000) - assert len(click_signal) == 1000 - click_signal = mir_eval.sonify.clicks( - times, fs, click=np.zeros(1000)) - assert len(click_signal) == times.max()*fs + 1000 + 1 - - -def test_time_frequency(): + click_signal = mir_eval.sonify.clicks(times, fs) + assert len(click_signal) == times.max() * fs + int(fs * 0.1) + 1 + click_signal = mir_eval.sonify.clicks(times, fs, length=1000) + assert len(click_signal) == 1000 + click_signal = mir_eval.sonify.clicks(times, fs, click=np.zeros(1000)) + assert len(click_signal) == times.max() * fs + 1000 + 1 + + +@pytest.mark.parametrize("fs", [8000, 44100]) +def test_time_frequency(fs): # Test length for different inputs - for fs in [8000, 44100]: - signal = mir_eval.sonify.time_frequency( - np.random.standard_normal((100, 1000)), np.arange(1, 101), - np.linspace(0, 10, 1000), fs) - assert len(signal) == 10*fs - signal = mir_eval.sonify.time_frequency( - np.random.standard_normal((100, 1000)), np.arange(1, 101), - np.linspace(0, 10, 1000), fs, length=fs*11) - assert len(signal) == 11*fs - - -def test_chroma(): - for fs in [8000, 44100]: - signal = mir_eval.sonify.chroma( - np.random.standard_normal((12, 1000)), - np.linspace(0, 10, 1000), fs) - assert len(signal) == 10*fs - signal = mir_eval.sonify.chroma( - np.random.standard_normal((12, 1000)), - np.linspace(0, 10, 1000), fs, length=fs*11) - assert len(signal) == 11*fs - - -def test_chords(): - for fs in [8000, 44100]: - intervals = np.array([np.arange(10), np.arange(1, 11)]).T - signal = mir_eval.sonify.chords( - ['C', 'C:maj', 'D:min7', 'E:min', 'C#', 'C', 'C', 'C', 'C', 'C'], - intervals, fs) - assert len(signal) == 10*fs - signal = mir_eval.sonify.chords( - ['C', 'C:maj', 'D:min7', 'E:min', 'C#', 'C', 'C', 'C', 'C', 'C'], - intervals, fs, length=fs*11) - assert len(signal) == 11*fs + signal = mir_eval.sonify.time_frequency( + np.random.standard_normal((100, 1000)), + np.arange(1, 101), + np.linspace(0, 10, 1000), + fs, + ) + assert len(signal) == 10 * fs + signal = mir_eval.sonify.time_frequency( + np.random.standard_normal((100, 1000)), + np.arange(1, 101), + np.linspace(0, 10, 1000), + fs, + length=fs * 11, + ) + assert len(signal) == 11 * fs + + +@pytest.mark.parametrize("fs", [8000, 44100]) +def test_chroma(fs): + signal = mir_eval.sonify.chroma( + np.random.standard_normal((12, 1000)), np.linspace(0, 10, 1000), fs + ) + assert len(signal) == 10 * fs + signal = mir_eval.sonify.chroma( + np.random.standard_normal((12, 1000)), + np.linspace(0, 10, 1000), + fs, + length=fs * 11, + ) + assert len(signal) == 11 * fs + + +@pytest.mark.parametrize("fs", [8000, 44100]) +# FIXME: #371 +@pytest.mark.skip(reason="Skipped until #371 is fixed") +def test_chords(fs): + intervals = np.array([np.arange(10), np.arange(1, 11)]).T + signal = mir_eval.sonify.chords( + ["C", "C:maj", "D:min7", "E:min", "C#", "C", "C", "C", "C", "C"], intervals, fs + ) + assert len(signal) == 10 * fs + signal = mir_eval.sonify.chords( + ["C", "C:maj", "D:min7", "E:min", "C#", "C", "C", "C", "C", "C"], + intervals, + fs, + length=fs * 11, + ) + assert len(signal) == 11 * fs def test_chord_x(): # This test verifies that X sonifies as silence intervals = np.array([[0, 1]]) - signal = mir_eval.sonify.chords(['X'], intervals, 8000) + signal = mir_eval.sonify.chords(["X"], intervals, 8000) assert not np.any(signal), signal def test_pitch_contour(): - # Generate some random pitch fs = 8000 times = np.linspace(0, 5, num=5 * fs, endpoint=True) - noise = scipy.ndimage.gaussian_filter1d(np.random.randn(len(times)), - sigma=256) - freqs = 440.0 * 2.0**(16 * noise) + noise = scipy.ndimage.gaussian_filter1d(np.random.randn(len(times)), sigma=256) + freqs = 440.0 * 2.0 ** (16 * noise) amps = np.linspace(0, 1, num=5 * fs, endpoint=True) # negate a bunch of sequences @@ -88,17 +101,16 @@ def test_pitch_contour(): # which should result in a constant sequence in the output x = mir_eval.sonify.pitch_contour(times, freqs, fs, length=fs * 7) assert len(x) == fs * 7 - assert np.allclose(x[-fs * 2:], x[-fs * 2]) + assert np.allclose(x[-fs * 2 :], x[-fs * 2]) # Test with an explicit duration and a fixed offset # This forces the interpolator to go off the beginning of # the sampling grid, which should result in a constant output x = mir_eval.sonify.pitch_contour(times + 5.0, freqs, fs, length=fs * 7) assert len(x) == fs * 7 - assert np.allclose(x[:fs * 5], x[0]) + assert np.allclose(x[: fs * 5], x[0]) # Test with explicit amplitude - x = mir_eval.sonify.pitch_contour(times, freqs, fs, length=fs * 7, - amplitudes=amps) + x = mir_eval.sonify.pitch_contour(times, freqs, fs, length=fs * 7, amplitudes=amps) assert len(x) == fs * 7 assert np.allclose(x[0], 0) diff --git a/tests/test_tempo.py b/tests/test_tempo.py index 79bdc3e6..e85ddd59 100644 --- a/tests/test_tempo.py +++ b/tests/test_tempo.py @@ -1,28 +1,44 @@ #!/usr/bin/env python -''' +""" Unit tests for mir_eval.tempo -''' -import warnings - +""" import numpy as np import mir_eval -from nose.tools import raises import json import glob +import pytest A_TOL = 1e-12 -def _load_tempi(filename): +REF_GLOB = "data/tempo/ref*.lab" +EST_GLOB = "data/tempo/est*.lab" +SCORES_GLOB = "data/tempo/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + - values = mir_eval.io.load_delimited(filename, [float] * 3) +@pytest.fixture +def tempo_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) - return np.concatenate(values[:2]), values[-1][0] + def _load_tempi(filename): + values = mir_eval.io.load_delimited(filename, [float] * 3) + return np.concatenate(values[:2]), values[-1][0] + reference_tempi, ref_weight = _load_tempi(ref_f) + estimated_tempi, _ = _load_tempi(est_f) -def __check_score(sco_f, metric, score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) + return reference_tempi, ref_weight, estimated_tempi, expected_scores def test_zero_tolerance_pass(): @@ -31,16 +47,11 @@ def test_zero_tolerance_pass(): good_est = np.array([120, 180]) zero_tol = 0.0 - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - # Try to produce the warning + with pytest.warns( + UserWarning, match="A tolerance of 0.0 may not lead to the results you expect" + ): mir_eval.tempo.detection(good_ref, good_weight, good_est, tol=zero_tol) - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert str(w[-1].message) == 'A tolerance of 0.0 may ' \ - 'not lead to the results you expect.' - def test_tempo_pass(): good_ref = np.array([60, 120]) @@ -48,65 +59,75 @@ def test_tempo_pass(): good_est = np.array([120, 180]) good_tol = 0.08 - for good_tempo in [np.array([50, 50]), np.array([0, 50]), - np.array([50, 0])]: - yield mir_eval.tempo.detection, good_tempo,\ - good_weight, good_est, good_tol - yield mir_eval.tempo.detection, good_ref,\ - good_weight, good_tempo, good_tol + for good_tempo in [np.array([50, 50]), np.array([0, 50]), np.array([50, 0])]: + mir_eval.tempo.detection(good_tempo, good_weight, good_est, good_tol) + mir_eval.tempo.detection(good_ref, good_weight, good_tempo, good_tol) # allow both estimates to be zero - yield mir_eval.tempo.detection, good_ref,\ - good_weight, np.array([0, 0]), good_tol + mir_eval.tempo.detection(good_ref, good_weight, np.array([0, 0]), good_tol) -def test_tempo_fail(): +@pytest.mark.xfail(raises=ValueError) +def test_tempo_zero_ref(): + # Both references cannot be zero + mir_eval.tempo.detection(np.array([0.0, 0.0]), 0.5, np.array([60, 120])) - @raises(ValueError) - def __test(ref, weight, est, tol): - mir_eval.tempo.detection(ref, weight, est, tol=tol) - good_ref = np.array([60, 120]) - good_weight = 0.5 - good_est = np.array([120, 180]) - good_tol = 0.08 - - for bad_tempo in [np.array([-1, -1]), np.array([-1, 0]), - np.array([-1, 50]), np.array([0, 1, 2]), np.array([0])]: - yield __test, bad_tempo, good_weight, good_est, good_tol - yield __test, good_ref, good_weight, bad_tempo, good_tol - - for bad_weight in [-1, 1.5]: - yield __test, good_ref, bad_weight, good_est, good_tol - - for bad_tol in [-1, 1.5]: - yield __test, good_ref, good_weight, good_est, bad_tol +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("weight", [-1, 1.5]) +def test_tempo_weight_range(weight): + # Weight needs to be in the range [0, 1] + mir_eval.tempo.detection(np.array([60, 120]), weight, np.array([120, 180])) - # don't allow both references to be zero - yield __test, np.array([0, 0]), good_weight, good_ref, good_tol +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("tol", [-1, 1.5]) +def test_tempo_tol_range(tol): + # Weight needs to be in the range [0, 1] + mir_eval.tempo.detection(np.array([60, 120]), 0.5, np.array([120, 180]), tol=tol) -def test_tempo_regression(): - REF_GLOB = 'data/tempo/ref*.lab' - EST_GLOB = 'data/tempo/est*.lab' - SCORES_GLOB = 'data/tempo/output*.json' - # Load in all files in the same order - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "bad_tempo", + [ + np.array([-1, -1]), + np.array([-1, 0]), + np.array([-1, 50]), + np.array([0, 1, 2]), + np.array([0]), + ], +) +def test_tempo_fail_bad_reftempo(bad_tempo): + good_ref = np.array([60, 120]) + good_est = np.array([120, 180]) - assert len(ref_files) == len(est_files) == len(sco_files) + mir_eval.tempo.detection(bad_tempo, 0.5, good_est, 0.08) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "bad_tempo", + [ + np.array([-1, -1]), + np.array([-1, 0]), + np.array([-1, 50]), + np.array([0, 1, 2]), + np.array([0]), + ], +) +def test_tempo_fail_bad_esttempo(bad_tempo): + good_ref = np.array([60, 120]) + good_est = np.array([120, 180]) - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as fdesc: - expected_scores = json.load(fdesc) + mir_eval.tempo.detection(good_ref, 0.5, bad_tempo, 0.08) - ref_tempi, ref_weight = _load_tempi(ref_f) - est_tempi, _ = _load_tempi(est_f) - scores = mir_eval.tempo.evaluate(ref_tempi, ref_weight, est_tempi) +@pytest.mark.parametrize("tempo_data", file_sets, indirect=True) +def test_tempo_regression(tempo_data): + ref_tempi, ref_weight, est_tempi, expected_scores = tempo_data - for metric in scores: - yield (__check_score, sco_f, metric, scores[metric], - expected_scores[metric]) + scores = mir_eval.tempo.evaluate(ref_tempi, ref_weight, est_tempi) + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 89fcf0f8..1450da28 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -4,28 +4,55 @@ import numpy as np import glob import json -from nose.tools import raises -import warnings +import pytest A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/transcription/ref*.txt' -EST_GLOB = 'data/transcription/est*.txt' -SCORES_GLOB = 'data/transcription/output*.json' - -REF = np.array([ - [0.100, 0.300, 220.000], - [0.300, 0.400, 246.942], - [0.500, 0.600, 277.183], - [0.550, 0.650, 293.665]]) - -EST = np.array([ - [0.120, 0.290, 225.000], - [0.300, 0.340, 246.942], - [0.500, 0.600, 500.000], - [0.550, 0.600, 293.665], - [0.560, 0.650, 293.665]]) +REF_GLOB = "data/transcription/ref*.txt" +EST_GLOB = "data/transcription/est*.txt" +SCORES_GLOB = "data/transcription/output*.json" + +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) + +assert len(ref_files) == len(est_files) == len(sco_files) > 0 + +file_sets = list(zip(ref_files, est_files, sco_files)) + + +@pytest.fixture +def transcription_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + # Load in an example segmentation annotation + ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f) + # Load in estimated transcription + est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f) + + return ref_int, ref_pitch, est_int, est_pitch, expected_scores + + +REF = np.array( + [ + [0.100, 0.300, 220.000], + [0.300, 0.400, 246.942], + [0.500, 0.600, 277.183], + [0.550, 0.650, 293.665], + ] +) + +EST = np.array( + [ + [0.120, 0.290, 225.000], + [0.300, 0.340, 246.942], + [0.500, 0.600, 500.000], + [0.550, 0.600, 293.665], + [0.560, 0.650, 293.665], + ] +) SCORES = { "Precision": 0.4, @@ -35,7 +62,7 @@ "Precision_no_offset": 0.6, "Recall_no_offset": 0.75, "F-measure_no_offset": 0.6666666666666665, - "Average_Overlap_Ratio_no_offset": 0.5833333333333333 + "Average_Overlap_Ratio_no_offset": 0.5833333333333333, } ONSET_SCORES = { @@ -52,275 +79,263 @@ def test_match_note_offsets(): - ref_int = REF[:, :2] est_int = EST[:, :2] - matching = ( - mir_eval.transcription.match_note_offsets(ref_int, est_int)) + matching = mir_eval.transcription.match_note_offsets(ref_int, est_int) assert matching == [(0, 0), (2, 2), (3, 3)] def test_match_note_offsets_strict(): - ref_int = REF[:, :2] est_int = EST[:, :2] - matching = ( - mir_eval.transcription.match_note_offsets( - ref_int, est_int, strict=True)) + matching = mir_eval.transcription.match_note_offsets(ref_int, est_int, strict=True) assert matching == [(0, 0), (2, 2), (3, 4)] def test_match_note_onsets(): - ref_int = REF[:, :2] est_int = EST[:, :2] - matching = ( - mir_eval.transcription.match_note_onsets(ref_int, est_int)) + matching = mir_eval.transcription.match_note_onsets(ref_int, est_int) assert matching == [(0, 0), (1, 1), (2, 2), (3, 3)] def test_match_note_onsets_strict(): - ref_int = REF[:, :2] est_int = EST[:, :2] - matching = ( - mir_eval.transcription.match_note_onsets( - ref_int, est_int, strict=True)) + matching = mir_eval.transcription.match_note_onsets(ref_int, est_int, strict=True) assert matching == [(0, 0), (1, 1), (2, 2), (3, 3)] def test_match_notes(): - ref_int, ref_pitch = REF[:, :2], REF[:, 2] est_int, est_pitch = EST[:, :2], EST[:, 2] - matching = ( - mir_eval.transcription.match_notes(ref_int, ref_pitch, est_int, - est_pitch)) + matching = mir_eval.transcription.match_notes( + ref_int, ref_pitch, est_int, est_pitch + ) assert matching == [(0, 0), (3, 3)] - matching = ( - mir_eval.transcription.match_notes(ref_int, ref_pitch, est_int, - est_pitch, offset_ratio=None)) + matching = mir_eval.transcription.match_notes( + ref_int, ref_pitch, est_int, est_pitch, offset_ratio=None + ) assert matching == [(0, 0), (1, 1), (3, 3)] def test_match_notes_strict(): - ref_int, ref_pitch = np.array([[0, 1]]), np.array([100]) est_int, est_pitch = np.array([[0.05, 1]]), np.array([100]) - matching = ( - mir_eval.transcription.match_notes(ref_int, ref_pitch, est_int, - est_pitch, strict=True)) + matching = mir_eval.transcription.match_notes( + ref_int, ref_pitch, est_int, est_pitch, strict=True + ) assert matching == [] def test_precision_recall_f1_overlap(): - # load test data ref_int, ref_pitch = REF[:, :2], REF[:, 2] est_int, est_pitch = EST[:, :2], EST[:, 2] - precision, recall, f_measure, avg_overlap_ratio = ( - mir_eval.transcription.precision_recall_f1_overlap( - ref_int, ref_pitch, est_int, est_pitch)) + ( + precision, + recall, + f_measure, + avg_overlap_ratio, + ) = mir_eval.transcription.precision_recall_f1_overlap( + ref_int, ref_pitch, est_int, est_pitch + ) scores_gen = np.array([precision, recall, f_measure, avg_overlap_ratio]) - scores_exp = np.array([SCORES['Precision'], SCORES['Recall'], - SCORES['F-measure'], - SCORES['Average_Overlap_Ratio']]) + scores_exp = np.array( + [ + SCORES["Precision"], + SCORES["Recall"], + SCORES["F-measure"], + SCORES["Average_Overlap_Ratio"], + ] + ) assert np.allclose(scores_exp, scores_gen, atol=A_TOL) - precision, recall, f_measure, avg_overlap_ratio = ( - mir_eval.transcription.precision_recall_f1_overlap( - ref_int, ref_pitch, est_int, est_pitch, offset_ratio=None)) + ( + precision, + recall, + f_measure, + avg_overlap_ratio, + ) = mir_eval.transcription.precision_recall_f1_overlap( + ref_int, ref_pitch, est_int, est_pitch, offset_ratio=None + ) scores_gen = np.array([precision, recall, f_measure, avg_overlap_ratio]) - scores_exp = np.array([SCORES['Precision_no_offset'], - SCORES['Recall_no_offset'], - SCORES['F-measure_no_offset'], - SCORES['Average_Overlap_Ratio_no_offset']]) + scores_exp = np.array( + [ + SCORES["Precision_no_offset"], + SCORES["Recall_no_offset"], + SCORES["F-measure_no_offset"], + SCORES["Average_Overlap_Ratio_no_offset"], + ] + ) assert np.allclose(scores_exp, scores_gen, atol=A_TOL) -def __check_score(score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - def test_onset_precision_recall_f1(): - # load test data ref_int = REF[:, :2] est_int = EST[:, :2] - precision, recall, f_measure = ( - mir_eval.transcription.onset_precision_recall_f1(ref_int, est_int)) + precision, recall, f_measure = mir_eval.transcription.onset_precision_recall_f1( + ref_int, est_int + ) scores_gen = np.array([precision, recall, f_measure]) - scores_exp = np.array([ONSET_SCORES['Onset_Precision'], - ONSET_SCORES['Onset_Recall'], - ONSET_SCORES['Onset_F-measure']]) + scores_exp = np.array( + [ + ONSET_SCORES["Onset_Precision"], + ONSET_SCORES["Onset_Recall"], + ONSET_SCORES["Onset_F-measure"], + ] + ) assert np.allclose(scores_exp, scores_gen, atol=A_TOL) def test_offset_precision_recall_f1(): - # load test data ref_int = REF[:, :2] est_int = EST[:, :2] - precision, recall, f_measure = ( - mir_eval.transcription.offset_precision_recall_f1(ref_int, est_int)) + precision, recall, f_measure = mir_eval.transcription.offset_precision_recall_f1( + ref_int, est_int + ) scores_gen = np.array([precision, recall, f_measure]) - scores_exp = np.array([OFFSET_SCORES['Offset_Precision'], - OFFSET_SCORES['Offset_Recall'], - OFFSET_SCORES['Offset_F-measure']]) + scores_exp = np.array( + [ + OFFSET_SCORES["Offset_Precision"], + OFFSET_SCORES["Offset_Recall"], + OFFSET_SCORES["Offset_F-measure"], + ] + ) assert np.allclose(scores_exp, scores_gen, atol=A_TOL) -def test_regression(): - - # Regression tests - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in reference transcription - ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_f) - # Load in estimated transcription - est_int, est_pitch = mir_eval.io.load_valued_intervals(est_f) - scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int, - est_pitch) - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, scores[metric], expected_scores[metric]) - - -def test_invalid_pitch(): - - ref_int, ref_pitch = np.array([[0, 1]]), np.array([-100]) - est_int, est_pitch = np.array([[0, 1]]), np.array([100]) +@pytest.mark.parametrize("transcription_data", file_sets, indirect=True) +def test_regression(transcription_data): + ref_int, ref_pitch, est_int, est_pitch, expected_scores = transcription_data - yield (raises(ValueError)(mir_eval.transcription.validate), - ref_int, ref_pitch, est_int, est_pitch) - yield (raises(ValueError)(mir_eval.transcription.validate), - est_int, est_pitch, ref_int, ref_pitch) + scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int, est_pitch) + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL) -def test_inconsistent_int_pitch(): +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "ref_pitch, est_pitch", + [(np.array([-100]), np.array([100])), (np.array([100]), np.array([-100]))], +) +def test_invalid_pitch(ref_pitch, est_pitch): + ref_int = np.array([[0, 1]]) + mir_eval.transcription.validate(ref_int, ref_pitch, ref_int, est_pitch) - ref_int, ref_pitch = np.array([[0, 1], [2, 3]]), np.array([100]) - est_int, est_pitch = np.array([[0, 1]]), np.array([100]) - yield (raises(ValueError)(mir_eval.transcription.validate), - ref_int, ref_pitch, est_int, est_pitch) - yield (raises(ValueError)(mir_eval.transcription.validate), - est_int, est_pitch, ref_int, ref_pitch) +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "ref_int, est_int", + [ + (np.array([[0, 1], [2, 3]]), np.array([[0, 1]])), + (np.array([[0, 1]]), np.array([[0, 1], [2, 3]])), + ], +) +def test_inconsistent_int_pitch(ref_int, est_int): + ref_pitch = np.array([100]) + mir_eval.transcription.validate(ref_int, ref_pitch, est_int, ref_pitch) def test_empty_ref(): + ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([]) + est_int, est_pitch = np.array([[0, 1]]), np.array([100]) - warnings.resetwarnings() - warnings.simplefilter('always') - with warnings.catch_warnings(record=True) as out: - - ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([]) - est_int, est_pitch = np.array([[0, 1]]), np.array([100]) - + with pytest.warns(UserWarning, match="Reference notes are empty"): mir_eval.transcription.validate(ref_int, ref_pitch, est_int, est_pitch) - # Make sure that the warning triggered - assert len(out) > 0 - - # And that the category is correct - assert out[0].category is UserWarning - - # And that it says the right thing (roughly) - assert 'empty' in str(out[0].message).lower() - def test_empty_est(): + ref_int, ref_pitch = np.array([[0, 1]]), np.array([100]) + est_int, est_pitch = np.empty(shape=(0, 2)), np.array([]) - warnings.resetwarnings() - warnings.simplefilter('always') - with warnings.catch_warnings(record=True) as out: - - ref_int, ref_pitch = np.array([[0, 1]]), np.array([100]) - est_int, est_pitch = np.empty(shape=(0, 2)), np.array([]) - + with pytest.warns(UserWarning, match="Estimated notes are empty"): mir_eval.transcription.validate(ref_int, ref_pitch, est_int, est_pitch) - # Make sure that the warning triggered - assert len(out) > 0 - - # And that the category is correct - assert out[0].category is UserWarning - - # And that it says the right thing (roughly) - assert 'empty' in str(out[0].message).lower() - +@pytest.mark.filterwarnings("ignore:.*notes are empty") def test_precision_recall_f1_overlap_empty(): - ref_int, ref_pitch = np.empty(shape=(0, 2)), np.array([]) est_int, est_pitch = np.array([[0, 1]]), np.array([100]) - precision, recall, f1, avg_overlap_ratio = ( - mir_eval.transcription.precision_recall_f1_overlap( - ref_int, ref_pitch, est_int, est_pitch)) + ( + precision, + recall, + f1, + avg_overlap_ratio, + ) = mir_eval.transcription.precision_recall_f1_overlap( + ref_int, ref_pitch, est_int, est_pitch + ) assert (precision, recall, f1) == (0, 0, 0) - precision, recall, f1, avg_overlap_ratio = ( - mir_eval.transcription.precision_recall_f1_overlap( - est_int, est_pitch, ref_int, ref_pitch)) + ( + precision, + recall, + f1, + avg_overlap_ratio, + ) = mir_eval.transcription.precision_recall_f1_overlap( + est_int, est_pitch, ref_int, ref_pitch + ) assert (precision, recall, f1) == (0, 0, 0) +@pytest.mark.filterwarnings("ignore:.*notes are empty") def test_onset_precision_recall_f1_empty(): - ref_int = np.empty(shape=(0, 2)) est_int = np.array([[0, 1]]) - precision, recall, f1 = ( - mir_eval.transcription.onset_precision_recall_f1(ref_int, est_int)) + precision, recall, f1 = mir_eval.transcription.onset_precision_recall_f1( + ref_int, est_int + ) assert (precision, recall, f1) == (0, 0, 0) - precision, recall, f1 = ( - mir_eval.transcription.onset_precision_recall_f1(est_int, ref_int)) + precision, recall, f1 = mir_eval.transcription.onset_precision_recall_f1( + est_int, ref_int + ) assert (precision, recall, f1) == (0, 0, 0) +@pytest.mark.filterwarnings("ignore:.*notes are empty") def test_offset_precision_recall_f1_empty(): - ref_int = np.empty(shape=(0, 2)) est_int = np.array([[0, 1]]) - precision, recall, f1 = ( - mir_eval.transcription.offset_precision_recall_f1(ref_int, est_int)) + precision, recall, f1 = mir_eval.transcription.offset_precision_recall_f1( + ref_int, est_int + ) assert (precision, recall, f1) == (0, 0, 0) - precision, recall, f1 = ( - mir_eval.transcription.offset_precision_recall_f1(est_int, ref_int)) + precision, recall, f1 = mir_eval.transcription.offset_precision_recall_f1( + est_int, ref_int + ) assert (precision, recall, f1) == (0, 0, 0) diff --git a/tests/test_transcription_velocity.py b/tests/test_transcription_velocity.py index b48e7194..85f9c8bf 100644 --- a/tests/test_transcription_velocity.py +++ b/tests/test_transcription_velocity.py @@ -1,104 +1,134 @@ +import pytest import mir_eval import numpy as np import glob import json -from nose.tools import raises A_TOL = 1e-12 # Path to the fixture files -REF_GLOB = 'data/transcription_velocity/ref*.txt' -EST_GLOB = 'data/transcription_velocity/est*.txt' -SCORES_GLOB = 'data/transcription_velocity/output*.json' +REF_GLOB = "data/transcription_velocity/ref*.txt" +EST_GLOB = "data/transcription_velocity/est*.txt" +SCORES_GLOB = "data/transcription_velocity/output*.json" +ref_files = sorted(glob.glob(REF_GLOB)) +est_files = sorted(glob.glob(EST_GLOB)) +sco_files = sorted(glob.glob(SCORES_GLOB)) +assert len(ref_files) == len(est_files) == len(sco_files) > 0 -def test_negative_velocity(): +file_sets = list(zip(ref_files, est_files, sco_files)) + + +def _load_transcription_velocity(filename): + """Loader for data in the format start, end, pitch, velocity.""" + starts, ends, pitches, velocities = mir_eval.io.load_delimited( + filename, [float, float, int, int] + ) + # Stack into an interval matrix + intervals = np.array([starts, ends]).T + # return pitches and velocities as np.ndarray + pitches = np.array(pitches) + velocities = np.array(velocities) + return intervals, pitches, velocities + + +@pytest.fixture +def velocity_data(request): + ref_f, est_f, sco_f = request.param + with open(sco_f, "r") as f: + expected_scores = json.load(f) + # Load in reference transcription + ref_int, ref_pitch, ref_vel = _load_transcription_velocity(ref_f) + # Load in estimated transcription + est_int, est_pitch, est_vel = _load_transcription_velocity(est_f) + return (ref_int, ref_pitch, ref_vel), (est_int, est_pitch, est_vel), expected_scores + + +def test_negative_velocity(): good_i, good_p, good_v = np.array([[0, 1]]), np.array([100]), np.array([1]) bad_i, bad_p, bad_v = np.array([[0, 1]]), np.array([100]), np.array([-1]) - yield (raises(ValueError)(mir_eval.transcription_velocity.validate), - bad_i, bad_p, bad_v, good_i, good_p, good_v) - yield (raises(ValueError)(mir_eval.transcription_velocity.validate), - good_i, good_p, good_v, bad_i, bad_p, bad_v) + with pytest.raises(ValueError): + mir_eval.transcription_velocity.validate( + bad_i, bad_p, bad_v, good_i, good_p, good_v + ) + with pytest.raises(ValueError): + mir_eval.transcription_velocity.validate( + good_i, good_p, good_v, bad_i, bad_p, bad_v + ) def test_wrong_shape_velocity(): - good_i, good_p, good_v = np.array([[0, 1]]), np.array([100]), np.array([1]) bad_i, bad_p, bad_v = np.array([[0, 1]]), np.array([100]), np.array([1, 2]) - yield (raises(ValueError)(mir_eval.transcription_velocity.validate), - bad_i, bad_p, bad_v, good_i, good_p, good_v) - yield (raises(ValueError)(mir_eval.transcription_velocity.validate), - good_i, good_p, good_v, bad_i, bad_p, bad_v) + with pytest.raises(ValueError): + mir_eval.transcription_velocity.validate( + bad_i, bad_p, bad_v, good_i, good_p, good_v + ) + with pytest.raises(ValueError): + mir_eval.transcription_velocity.validate( + good_i, good_p, good_v, bad_i, bad_p, bad_v + ) def test_precision_recall_f1_overlap(): # Simple unit test - ref_i = np.array([[0, 1], [.5, .7], [1, 2]]) + ref_i = np.array([[0, 1], [0.5, 0.7], [1, 2]]) ref_p = np.array([100, 110, 80]) ref_v = np.array([10, 90, 110]) - est_i = np.array([[0, 1], [.5, .7], [1, 2]]) + est_i = np.array([[0, 1], [0.5, 0.7], [1, 2]]) est_p = np.array([100, 110, 80]) est_v = np.array([10, 70, 110]) p, r, f, o = mir_eval.transcription_velocity.precision_recall_f1_overlap( - ref_i, ref_p, ref_v, est_i, est_p, est_v) - assert np.allclose((p, r, f, o), (2/3., 2/3., 2/3., 1.)) + ref_i, ref_p, ref_v, est_i, est_p, est_v + ) + assert np.allclose((p, r, f, o), (2 / 3.0, 2 / 3.0, 2 / 3.0, 1.0)) p, r, f, o = mir_eval.transcription_velocity.precision_recall_f1_overlap( - ref_i, ref_p, ref_v, est_i, est_p, est_v, velocity_tolerance=0.3) - assert np.allclose((p, r, f, o), (1., 1., 1., 1.)) + ref_i, ref_p, ref_v, est_i, est_p, est_v, velocity_tolerance=0.3 + ) + assert np.allclose((p, r, f, o), (1.0, 1.0, 1.0, 1.0)) +# Suppressing this warning. We know the notes are empty, that's not the point. +@pytest.mark.filterwarnings("ignore:.*notes are empty") def test_precision_recall_f1_overlap_empty(): good_i, good_p, good_v = np.array([[0, 1]]), np.array([100]), np.array([1]) bad_i, bad_p, bad_v = np.empty((0, 2)), np.array([]), np.array([]) p, r, f, o = mir_eval.transcription_velocity.precision_recall_f1_overlap( - good_i, good_p, good_v, bad_i, bad_p, bad_v) - assert (p, r, f, o) == (0., 0., 0., 0.) + good_i, good_p, good_v, bad_i, bad_p, bad_v + ) + assert (p, r, f, o) == (0.0, 0.0, 0.0, 0.0) p, r, f, o = mir_eval.transcription_velocity.precision_recall_f1_overlap( - bad_i, bad_p, bad_v, good_i, good_p, good_v) - assert (p, r, f, o) == (0., 0., 0., 0.) + bad_i, bad_p, bad_v, good_i, good_p, good_v + ) + assert (p, r, f, o) == (0.0, 0.0, 0.0, 0.0) def test_precision_recall_f1_overlap_no_overlap(): p, r, f, o = mir_eval.transcription_velocity.precision_recall_f1_overlap( - np.array([[1, 2]]), np.array([1]), np.array([1]), - np.array([[3, 4]]), np.array([1]), np.array([1])) - assert (p, r, f, o) == (0., 0., 0., 0.) - - -def __check_score(score, expected_score): - assert np.allclose(score, expected_score, atol=A_TOL) - - -def test_regression(): - - def _load_transcription_velocity(filename): - """Loader for data in the format start, end, pitch, velocity.""" - starts, ends, pitches, velocities = mir_eval.io.load_delimited( - filename, [float, float, int, int]) - # Stack into an interval matrix - intervals = np.array([starts, ends]).T - # return pitches and velocities as np.ndarray - pitches = np.array(pitches) - velocities = np.array(velocities) - return intervals, pitches, velocities - - # Regression tests - ref_files = sorted(glob.glob(REF_GLOB)) - est_files = sorted(glob.glob(EST_GLOB)) - sco_files = sorted(glob.glob(SCORES_GLOB)) - - for ref_f, est_f, sco_f in zip(ref_files, est_files, sco_files): - with open(sco_f, 'r') as f: - expected_scores = json.load(f) - # Load in reference transcription - ref_int, ref_pitch, ref_vel = _load_transcription_velocity(ref_f) - # Load in estimated transcription - est_int, est_pitch, est_vel = _load_transcription_velocity(est_f) - scores = mir_eval.transcription_velocity.evaluate( - ref_int, ref_pitch, ref_vel, est_int, est_pitch, est_vel) - for metric in scores: - # This is a simple hack to make nosetest's messages more useful - yield (__check_score, scores[metric], expected_scores[metric]) + np.array([[1, 2]]), + np.array([1]), + np.array([1]), + np.array([[3, 4]]), + np.array([1]), + np.array([1]), + ) + assert (p, r, f, o) == (0.0, 0.0, 0.0, 0.0) + + +@pytest.mark.parametrize("velocity_data", file_sets, indirect=True) +def test_regression(velocity_data): + ( + (ref_int, ref_pitch, ref_vel), + (est_int, est_pitch, est_vel), + expected_scores, + ) = velocity_data + + scores = mir_eval.transcription_velocity.evaluate( + ref_int, ref_pitch, ref_vel, est_int, est_pitch, est_vel + ) + assert scores.keys() == expected_scores.keys() + for metric in scores: + assert np.allclose(scores[metric], expected_scores[metric], atol=A_TOL), metric diff --git a/tests/test_util.py b/tests/test_util.py index 9e01a70e..a87661e3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1,10 @@ -''' Unit tests for utils -''' +""" Unit tests for utils +""" import collections +import pytest import numpy as np -import nose.tools import mir_eval from mir_eval import util @@ -13,30 +13,32 @@ def test_interpolate_intervals(): """Check that an interval set is interpolated properly, with boundaries conditions and out-of-range values. """ - labels = list('abc') + labels = list("abc") intervals = np.array([(n, n + 1.0) for n in range(len(labels))]) time_points = [-1.0, 0.1, 0.9, 1.0, 2.3, 4.0] - expected_ans = ['N', 'a', 'a', 'b', 'c', 'N'] - assert (util.interpolate_intervals(intervals, labels, time_points, 'N') == - expected_ans) + expected_ans = ["N", "a", "a", "b", "c", "N"] + assert ( + util.interpolate_intervals(intervals, labels, time_points, "N") == expected_ans + ) def test_interpolate_intervals_gap(): """Check that an interval set is interpolated properly, with gaps.""" - labels = list('abc') + labels = list("abc") intervals = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]) time_points = [0.0, 0.75, 1.25, 1.75, 2.25, 2.75, 3.5] - expected_ans = ['N', 'a', 'N', 'b', 'N', 'c', 'N'] - assert (util.interpolate_intervals(intervals, labels, time_points, 'N') == - expected_ans) + expected_ans = ["N", "a", "N", "b", "N", "c", "N"] + assert ( + util.interpolate_intervals(intervals, labels, time_points, "N") == expected_ans + ) -@nose.tools.raises(ValueError) +@pytest.mark.xfail(raises=ValueError) def test_interpolate_intervals_badtime(): """Check that interpolate_intervals throws an exception if input is unordered. """ - labels = list('abc') + labels = list("abc") intervals = np.array([(n, n + 1.0) for n in range(len(labels))]) time_points = [-1.0, 0.1, 0.9, 0.8, 2.3, 4.0] mir_eval.util.interpolate_intervals(intervals, labels, time_points) @@ -46,64 +48,58 @@ def test_intervals_to_samples(): """Check that an interval set is sampled properly, with boundaries conditions and out-of-range values. """ - labels = list('abc') + labels = list("abc") intervals = np.array([(n, n + 1.0) for n in range(len(labels))]) expected_times = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] - expected_labels = ['a', 'a', 'b', 'b', 'c', 'c'] + expected_labels = ["a", "a", "b", "b", "c", "c"] result = util.intervals_to_samples( - intervals, labels, offset=0, sample_size=0.5, fill_value='N') + intervals, labels, offset=0, sample_size=0.5, fill_value="N" + ) assert result[0] == expected_times assert result[1] == expected_labels expected_times = [0.25, 0.75, 1.25, 1.75, 2.25, 2.75] - expected_labels = ['a', 'a', 'b', 'b', 'c', 'c'] + expected_labels = ["a", "a", "b", "b", "c", "c"] result = util.intervals_to_samples( - intervals, labels, offset=0.25, sample_size=0.5, fill_value='N') + intervals, labels, offset=0.25, sample_size=0.5, fill_value="N" + ) assert result[0] == expected_times assert result[1] == expected_labels def test_intersect_files(): - """Check that two non-identical yield correct results. - """ - flist1 = ['/a/b/abc.lab', '/c/d/123.lab', '/e/f/xyz.lab'] - flist2 = ['/g/h/xyz.npy', '/i/j/123.txt', '/k/l/456.lab'] + """Check that two non-identical produce correct results.""" + flist1 = ["/a/b/abc.lab", "/c/d/123.lab", "/e/f/xyz.lab"] + flist2 = ["/g/h/xyz.npy", "/i/j/123.txt", "/k/l/456.lab"] sublist1, sublist2 = util.intersect_files(flist1, flist2) - assert sublist1 == ['/e/f/xyz.lab', '/c/d/123.lab'] - assert sublist2 == ['/g/h/xyz.npy', '/i/j/123.txt'] + assert sublist1 == ["/e/f/xyz.lab", "/c/d/123.lab"] + assert sublist2 == ["/g/h/xyz.npy", "/i/j/123.txt"] sublist1, sublist2 = util.intersect_files(flist1[:1], flist2[:1]) assert sublist1 == [] assert sublist2 == [] def test_merge_labeled_intervals(): - """Check that two labeled interval sequences merge correctly. - """ - x_intvs = np.array([ - [0.0, 0.44], - [0.44, 2.537], - [2.537, 4.511], - [4.511, 6.409]]) - x_labels = ['A', 'B', 'C', 'D'] - y_intvs = np.array([ - [0.0, 0.464], - [0.464, 2.415], - [2.415, 4.737], - [4.737, 6.409]]) + """Check that two labeled interval sequences merge correctly.""" + x_intvs = np.array([[0.0, 0.44], [0.44, 2.537], [2.537, 4.511], [4.511, 6.409]]) + x_labels = ["A", "B", "C", "D"] + y_intvs = np.array([[0.0, 0.464], [0.464, 2.415], [2.415, 4.737], [4.737, 6.409]]) y_labels = [0, 1, 2, 3] expected_intvs = [ - [0.0, 0.44], - [0.44, 0.464], + [0.0, 0.44], + [0.44, 0.464], [0.464, 2.415], [2.415, 2.537], [2.537, 4.511], [4.511, 4.737], - [4.737, 6.409]] - expected_x_labels = ['A', 'B', 'B', 'B', 'C', 'D', 'D'] - expected_y_labels = [0, 0, 1, 2, 2, 2, 3] + [4.737, 6.409], + ] + expected_x_labels = ["A", "B", "B", "B", "C", "D", "D"] + expected_y_labels = [0, 0, 1, 2, 2, 2, 3] new_intvs, new_x_labels, new_y_labels = util.merge_labeled_intervals( - x_intvs, x_labels, y_intvs, y_labels) + x_intvs, x_labels, y_intvs, y_labels + ) assert new_x_labels == expected_x_labels assert new_y_labels == expected_y_labels @@ -111,8 +107,8 @@ def test_merge_labeled_intervals(): # Check that invalid inputs raise a ValueError y_intvs[-1, -1] = 10.0 - nose.tools.assert_raises(ValueError, util.merge_labeled_intervals, x_intvs, - x_labels, y_intvs, y_labels) + with pytest.raises(ValueError): + util.merge_labeled_intervals(x_intvs, x_labels, y_intvs, y_labels) def test_boundaries_to_intervals(): @@ -127,19 +123,19 @@ def test_adjust_events(): # Test appending at the end events = np.arange(1, 11) labels = [str(n) for n in range(10)] - new_e, new_l = mir_eval.util.adjust_events(events, labels, 0.0, 11.) - assert new_e[0] == 0. - assert new_l[0] == '__T_MIN' - assert new_e[-1] == 11. - assert new_l[-1] == '__T_MAX' + new_e, new_l = mir_eval.util.adjust_events(events, labels, 0.0, 11.0) + assert new_e[0] == 0.0 + assert new_l[0] == "__T_MIN" + assert new_e[-1] == 11.0 + assert new_l[-1] == "__T_MAX" assert np.all(new_e[1:-1] == events) assert new_l[1:-1] == labels # Test trimming - new_e, new_l = mir_eval.util.adjust_events(events, labels, 0.0, 9.) - assert new_e[0] == 0. - assert new_l[0] == '__T_MIN' - assert new_e[-1] == 9. + new_e, new_l = mir_eval.util.adjust_events(events, labels, 0.0, 9.0) + assert new_e[0] == 0.0 + assert new_l[0] == "__T_MIN" + assert new_e[-1] == 9.0 assert np.all(new_e[1:] == events[:-1]) assert new_l[1:] == labels[:-1] @@ -157,23 +153,23 @@ def test_bipartite_match(): # G = collections.defaultdict(list) - u_set = ['u{:d}'.format(_) for _ in range(10)] - v_set = ['v{:d}'.format(_) for _ in range(len(u_set)+1)] + u_set = ["u{:d}".format(_) for _ in range(10)] + v_set = ["v{:d}".format(_) for _ in range(len(u_set) + 1)] for i, u in enumerate(u_set): - for v in v_set[:-i-1]: + for v in v_set[: -i - 1]: G[v].append(u) matching = util._bipartite_match(G) # Make sure that each u vertex is matched - nose.tools.eq_(len(matching), len(u_set)) + assert len(matching) == len(u_set) # Make sure that there are no duplicate keys lhs = set([k for k in matching]) rhs = set([matching[k] for k in matching]) - nose.tools.eq_(len(matching), len(lhs)) - nose.tools.eq_(len(matching), len(rhs)) + assert len(matching) == len(lhs) + assert len(matching) == len(rhs) # Finally, make sure that all detected edges are present in G for k in matching: @@ -182,44 +178,50 @@ def test_bipartite_match(): def test_outer_distance_mod_n(): - ref = [1., 2., 3.] - est = [1.1, 6., 1.9, 5., 10.] - expected = np.array([ - [0.1, 5., 0.9, 4., 3.], - [0.9, 4., 0.1, 3., 4.], - [1.9, 3., 1.1, 2., 5.]]) + ref = [1.0, 2.0, 3.0] + est = [1.1, 6.0, 1.9, 5.0, 10.0] + expected = np.array( + [ + [0.1, 5.0, 0.9, 4.0, 3.0], + [0.9, 4.0, 0.1, 3.0, 4.0], + [1.9, 3.0, 1.1, 2.0, 5.0], + ] + ) actual = mir_eval.util._outer_distance_mod_n(ref, est) assert np.allclose(actual, expected) - ref = [13., 14., 15.] - est = [1.1, 6., 1.9, 5., 10.] - expected = np.array([ - [0.1, 5., 0.9, 4., 3.], - [0.9, 4., 0.1, 3., 4.], - [1.9, 3., 1.1, 2., 5.]]) + ref = [13.0, 14.0, 15.0] + est = [1.1, 6.0, 1.9, 5.0, 10.0] + expected = np.array( + [ + [0.1, 5.0, 0.9, 4.0, 3.0], + [0.9, 4.0, 0.1, 3.0, 4.0], + [1.9, 3.0, 1.1, 2.0, 5.0], + ] + ) actual = mir_eval.util._outer_distance_mod_n(ref, est) assert np.allclose(actual, expected) def test_match_events(): - ref = [1., 2., 3.] - est = [1.1, 6., 1.9, 5., 10.] + ref = [1.0, 2.0, 3.0] + est = [1.1, 6.0, 1.9, 5.0, 10.0] expected = [(0, 0), (1, 2)] actual = mir_eval.util.match_events(ref, est, 0.5) assert actual == expected - ref = [1., 2., 3., 11.9] - est = [1.1, 6., 1.9, 5., 10., 0.] + ref = [1.0, 2.0, 3.0, 11.9] + est = [1.1, 6.0, 1.9, 5.0, 10.0, 0.0] expected = [(0, 0), (1, 2), (3, 5)] actual = mir_eval.util.match_events( - ref, est, 0.5, distance=mir_eval.util._outer_distance_mod_n) + ref, est, 0.5, distance=mir_eval.util._outer_distance_mod_n + ) assert actual == expected def test_fast_hit_windows(): - - ref = [1., 2., 3.] - est = [1.1, 6., 1.9, 5., 10.] + ref = [1.0, 2.0, 3.0] + est = [1.1, 6.0, 1.9, 5.0, 10.0] ref_fast, est_fast = mir_eval.util._fast_hit_windows(ref, est, 0.5) ref_slow, est_slow = np.where(np.abs(np.subtract.outer(ref, est)) <= 0.5) @@ -228,70 +230,74 @@ def test_fast_hit_windows(): assert np.all(est_fast == est_slow) -def test_validate_intervals(): - # Test for ValueError when interval shape is invalid - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_intervals, - np.array([[1.], [2.5], [5.]])) - # Test for ValueError when times are negative - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_intervals, - np.array([[1., -2.], [2.5, 3.], [5., 6.]])) - # Test for ValueError when duration is zero - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_intervals, - np.array([[1., 2.], [2.5, 2.5], [5., 6.]])) - # Test for ValueError when duration is negative - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_intervals, - np.array([[1., 2.], [2.5, 1.5], [5., 6.]])) - - -def test_validate_events(): - # Test for ValueError when max_time is violated - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_events, np.array([100., 100000.])) - # Test for ValueError when events aren't 1-d arrays - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_events, - np.array([[1., 2.], [3., 4.]])) - # Test for ValueError when event times are not increasing - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_events, - np.array([1., 2., 5., 3.])) - - -def test_validate_frequencies(): - # Test for ValueError when max_freq is violated - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([100., 100000.]), 5000., 20.) - # Test for ValueError when min_freq is violated - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([2., 200.]), 5000., 20.) - # Test for ValueError when events aren't 1-d arrays - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([[100., 200.], [300., 400.]]), 5000., 20.) - # Test for ValueError when allow_negatives is false and negative values - # are passed - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([[-100., 200.], [300., 400.]]), 5000., 20., - allow_negatives=False) - # Test for ValueError when max_freq is violated and allow_negatives=True - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([100., -100000.]), 5000., 20., allow_negatives=True) - # Test for ValueError when min_freq is violated and allow_negatives=True - nose.tools.assert_raises( - ValueError, mir_eval.util.validate_frequencies, - np.array([-2., 200.]), 5000., 20., allow_negatives=True) +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "intervals", + [ + # Test for ValueError when interval shape is invalid + np.array([[1.0], [2.5], [5.0]]), + # Test for ValueError when times are negative + np.array([[1.0, -2.0], [2.5, 3.0], [5.0, 6.0]]), + # Test for ValueError when duration is zero + np.array([[1.0, 2.0], [2.5, 2.5], [5.0, 6.0]]), + # Test for ValueError when duration is negative + np.array([[1.0, 2.0], [2.5, 1.5], [5.0, 6.0]]), + ], +) +def test_validate_intervals(intervals): + mir_eval.util.validate_intervals(intervals) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "events", + [ + # Test for ValueError when max_time is violated + np.array([100.0, 100000.0]), + # Test for ValueError when events aren't 1-d arrays + np.array([[1.0, 2.0], [3.0, 4.0]]), + # Test for ValueError when event times are not increasing + np.array([1.0, 2.0, 5.0, 3.0]), + ], +) +def test_validate_events(events): + mir_eval.util.validate_events(events) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "freqs", + [ + # Test for ValueError when max_freq is violated + np.array([100, 10000]), + # Test for ValueError when min_freq is violated + np.array([2, 200]), + # Test for ValueError when events aren't 1-d arrays + np.array([[100, 200], [300, 400]]), + # Test for ValueError when allow_negatives is false and negative values + # are passed + np.array([-100, 200]), + ], +) +def test_validate_frequencies(freqs): + mir_eval.util.validate_frequencies(freqs, 5000, 20, allow_negatives=False) + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize( + "freqs", + [ + # Test for ValueError when max_freq is violated and allow_negatives=True + np.array([100, -100000]), + # Test for ValueError when min_freq is violated and allow_negatives=True + np.array([-2, 200]), + ], +) +def test_validate_frequencies_negative(freqs): + mir_eval.util.validate_frequencies(freqs, 5000, 20, allow_negatives=True) def test_has_kwargs(): - def __test(target, f): assert target == mir_eval.util.has_kwargs(f) @@ -310,31 +316,43 @@ def f4(_, **kw): def f5(_=5, **kw): return None - yield __test, False, f1 - yield __test, False, f2 - yield __test, False, f3 - yield __test, True, f4 - yield __test, True, f5 - - -def test_sort_labeled_intervals(): - - def __test_labeled(x, labels, x_true, lab_true): - xs, ls = mir_eval.util.sort_labeled_intervals(x, labels) - - assert np.allclose(xs, x_true) - nose.tools.eq_(ls, lab_true) - - def __test(x, x_true): - xs = mir_eval.util.sort_labeled_intervals(x) - assert np.allclose(xs, x_true) - - x1 = np.asarray([[10, 20], [0, 10]]) - x1_true = np.asarray([[0, 10], [10, 20]]) - labels = ['a', 'b'] - labels_true = ['b', 'a'] - - yield __test_labeled, x1, labels, x1_true, labels_true - yield __test, x1, x1_true - yield __test_labeled, x1_true, labels_true, x1_true, labels_true - yield __test, x1_true, x1_true + assert not mir_eval.util.has_kwargs(f1) + assert not mir_eval.util.has_kwargs(f2) + assert not mir_eval.util.has_kwargs(f3) + assert mir_eval.util.has_kwargs(f4) + assert mir_eval.util.has_kwargs(f5) + + +@pytest.mark.parametrize( + "x,labels,x_true,lab_true", + [ + ( + np.asarray([[10, 20], [0, 10]]), + ["a", "b"], + np.asarray([[0, 10], [10, 20]]), + ["b", "a"], + ), + ( + np.asarray([[0, 10], [10, 20]]), + ["b", "a"], + np.asarray([[0, 10], [10, 20]]), + ["b", "a"], + ), + ], +) +def test_sort_labeled_intervals_with_labels(x, labels, x_true, lab_true): + xs, ls = mir_eval.util.sort_labeled_intervals(x, labels) + assert np.allclose(xs, x_true) + assert ls == lab_true + + +@pytest.mark.parametrize( + "x,x_true", + [ + (np.asarray([[10, 20], [0, 10]]), np.asarray([[0, 10], [10, 20]])), + (np.asarray([[0, 10], [10, 20]]), np.asarray([[0, 10], [10, 20]])), + ], +) +def test_sort_labeled_intervals_without_labels(x, x_true): + xs = mir_eval.util.sort_labeled_intervals(x) + assert np.allclose(xs, x_true)