From cf4807881d45aa3aa3db592cba7419ae7889fe9e Mon Sep 17 00:00:00 2001 From: Justin Salamon Date: Mon, 15 Feb 2016 20:39:24 -0500 Subject: [PATCH] Add strict option, improve offset_hit_matrix calc Add strict option for using <. The default is strict=False, which means using <=. Use a more efficient metho to compute the offset_hit_matric. Rename cmp to cmp_func. --- .gitignore | 36 +++---------------------- mir_eval/transcription.py | 56 ++++++++++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 0c951937..b878b2c7 100644 --- a/.gitignore +++ b/.gitignore @@ -40,36 +40,6 @@ Thumbs.db # pycharm .idea/* -docs/_build/doctrees/environment.pickle -docs/_build/doctrees/index.doctree -docs/_build/html/.buildinfo -docs/_build/html/genindex.html -docs/_build/html/index.html -docs/_build/html/objects.inv -docs/_build/html/py-modindex.html -docs/_build/html/search.html -docs/_build/html/searchindex.js -docs/_build/html/_images/math/55ad58c223214dd725b3a914744b50efa8973baf.png -docs/_build/html/_images/math/63cc48858eff17de24f0e9e1ca2d0459e0fabfe4.png -docs/_build/html/_images/math/950b26ea605ab18bef0159f0aa3092055f7d5603.png -docs/_build/html/_sources/index.txt -docs/_build/html/_static/ajax-loader.gif -docs/_build/html/_static/basic.css -docs/_build/html/_static/comment-bright.png -docs/_build/html/_static/comment-close.png -docs/_build/html/_static/comment.png -docs/_build/html/_static/default.css -docs/_build/html/_static/doctools.js -docs/_build/html/_static/down-pressed.png -docs/_build/html/_static/down.png -docs/_build/html/_static/file.png -docs/_build/html/_static/jquery.js -docs/_build/html/_static/minus.png -docs/_build/html/_static/plus.png -docs/_build/html/_static/pygments.css -docs/_build/html/_static/searchtools.js -docs/_build/html/_static/sidebar.js -docs/_build/html/_static/underscore.js -docs/_build/html/_static/up-pressed.png -docs/_build/html/_static/up.png -docs/_build/html/_static/websupport.js + +# docs +docs/_build/* \ No newline at end of file diff --git a/mir_eval/transcription.py b/mir_eval/transcription.py index 6d93410e..05633edd 100644 --- a/mir_eval/transcription.py +++ b/mir_eval/transcription.py @@ -2,7 +2,7 @@ The aim of a transcription algorithm is to produce a symbolic representation of a recorded piece of music in the form of a set of discrete notes. There are different ways to represent notes symbolically. Here we use the piano-roll -convention, meaning each notes has a start time, a duration (or end time), and +convention, meaning each note has a start time, a duration (or end time), and a single, constant, pitch value. Pitch values can be quantized (e.g. to a semitone grid tuned to 440 Hz), but do not have to be. Also, the transcription can contain the notes of a single instrument or voice (for example the melody), @@ -36,6 +36,18 @@ Salamon, J. (2013). Melody Extraction from Polyphonic Music Signals. Ph.D. thesis, Universitat Pompeu Fabra, Barcelona, Spain, 2013. +Note: two different evaluation scripts have been used in MIREX over the years, +where one uses ``<`` for matching onsets, offsets, and pitch values, whilst +the other uses ``<=`` for these checks. That is, if the distance between two +onsets is exactly equal to the defined threshold (e.g. 0.05), then the former +script would not consider the two notes to have matching onsets, whilst the +latter would. `mir_eval` provides both options: by default the latter +(``<=``) is used, but you can set ``strict=True`` when calling +:func:`mir_eval.transcription.precision_recall_f1()` in which case ``<`` will +be used. The default value (``strict=False``) matches the evaluation code that +was used to produce the results reported on the MIREX website for the "Su" +dataset in 2015. + Conventions ----------- @@ -102,13 +114,14 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches): def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05, pitch_tolerance=50.0, offset_ratio=0.2, - offset_min_tolerance=0.05): + offset_min_tolerance=0.05, strict=False): """Compute a maximum matching between reference and estimated notes, subject to onset, pitch and (optionally) offset constraints. Given two note sequences represented by ``ref_intervals``, ``ref_pitches``, - ``est_intervals`` and ``est_pitches`` (see ``io.load_valued_intervals``), - we seek the largest set of correspondences ``(i, j)`` such that: + ``est_intervals`` and ``est_pitches`` + (see :func:`mir_eval.io.load_valued_intervals`), we seek the largest set + of correspondences ``(i, j)`` such that: 1. The onset of ref note i is within ``onset_tolerance`` of the onset of est note j. @@ -157,6 +170,11 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, for an explanation of how the offset tolerance is determined. Note: this parameter only influences the results if ``offset_ratio`` is not ``None``. + strict: bool + If ``strict=False`` (the default), threshold checks for onset, offset, + and pitch matching are performed using ``<=`` (less than or equal). If + ``strict=True``, the threshold checks are performed using ``<`` (less + than). Returns ------- @@ -165,15 +183,21 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, ``matching[i] == (i, j)`` where reference note i matches estimate note j. """ + # set the comparison function + if strict: + cmp_func = np.less + else: + cmp_func = np.less_equal + # check for onset matches onset_distances = np.abs(np.subtract.outer(ref_intervals[:, 0], est_intervals[:, 0])) - onset_hit_matrix = onset_distances < onset_tolerance + onset_hit_matrix = cmp_func(onset_distances, onset_tolerance) # check for pitch matches pitch_distances = np.abs(1200*np.log2(np.divide.outer(ref_pitches, est_pitches))) - pitch_hit_matrix = pitch_distances < pitch_tolerance + pitch_hit_matrix = cmp_func(pitch_distances, pitch_tolerance) # check for offset matches if offset_ratio is not None if offset_ratio is not None: @@ -181,11 +205,10 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, est_intervals[:, 1])) ref_durations = util.intervals_to_durations(ref_intervals) offset_tolerances = 0.5 * offset_ratio * ref_durations - offset_tolerance_matrix = np.tile(offset_tolerances, - (offset_distances.shape[1], 1)).T - min_tolerance_inds = offset_tolerance_matrix < offset_min_tolerance - offset_tolerance_matrix[min_tolerance_inds] = offset_min_tolerance - offset_hit_matrix = offset_distances < offset_tolerance_matrix + min_tolerance_inds = offset_tolerances < offset_min_tolerance + offset_tolerances[min_tolerance_inds] = offset_min_tolerance + offset_hit_matrix = \ + cmp_func(offset_distances, offset_tolerances.reshape(-1, 1)) else: offset_hit_matrix = np.ones_like(onset_hit_matrix) @@ -211,7 +234,8 @@ def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches, def precision_recall_f1(ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05, pitch_tolerance=50.0, - offset_ratio=0.2, offset_min_tolerance=0.05): + offset_ratio=0.2, offset_min_tolerance=0.05, + strict=False): """Compute the Precision, Recall and F-measure of correct vs incorrectly transcribed notes. "Correctness" is determined based on note onset, pitch and (optionally) offset: an estimated note is assumed correct if its onset @@ -265,6 +289,11 @@ def precision_recall_f1(ref_intervals, ref_pitches, est_intervals, est_pitches, for an explanation of how the offset tolerance is determined. Note: this parameter only influences the results if offset_ratio is not ``None``. + strict: bool + If ``strict=False`` (the default), threshold checks for onset, offset, + and pitch matching are performed using ``<=`` (less than or equal). If + ``strict=True``, the threshold checks are performed using ``<`` (less + than). Returns ------- @@ -284,7 +313,8 @@ def precision_recall_f1(ref_intervals, ref_pitches, est_intervals, est_pitches, est_pitches, onset_tolerance=onset_tolerance, pitch_tolerance=pitch_tolerance, offset_ratio=offset_ratio, - offset_min_tolerance=offset_min_tolerance) + offset_min_tolerance=offset_min_tolerance, + strict=strict) precision = float(len(matching))/len(est_pitches) recall = float(len(matching))/len(ref_pitches)