diff --git a/evaluators/separation_eval.py b/evaluators/separation_eval.py
index a9c8e0b7..2f6bde16 100755
--- a/evaluators/separation_eval.py
+++ b/evaluators/separation_eval.py
@@ -12,7 +12,6 @@
 import sys
 import os
 import glob
-import os
 import numpy as np
 import eval_utilities
 
diff --git a/mir_eval/separation.py b/mir_eval/separation.py
index 43d19632..a629613d 100644
--- a/mir_eval/separation.py
+++ b/mir_eval/separation.py
@@ -185,6 +185,88 @@ def bss_eval_sources(reference_sources, estimated_sources):
     return (sdr[idx], sir[idx], sar[idx], popt)
 
 
+def bss_eval_sources_framewise(
+    reference_sources, estimated_sources, win, hop
+):
+    """Framewise computation of bss_eval_sources
+
+    The signals are cut into frames of ``win`` samples taken every ``hop``
+    samples, and bss_eval_sources is run on each frame independently.
+    Trailing samples that do not fill a complete frame are ignored.
+
+    Examples
+    --------
+    >>> # reference_sources[n] should be an ndarray of samples of the
+    >>> # n'th reference source
+    >>> # estimated_sources[n] should be the same for the n'th estimated
+    >>> # source
+    >>> (sdr, sir, sar,
+    ...  perm) = mir_eval.separation.bss_eval_sources_framewise(
+    ...      reference_sources, estimated_sources, win, hop)
+
+    Parameters
+    ----------
+    reference_sources : np.ndarray, shape=(nsrc, nsampl)
+        matrix containing true sources
+    estimated_sources : np.ndarray, shape=(nsrc, nsampl)
+        matrix containing estimated sources
+    win : int
+        Window length in samples, must be positive
+    hop : int
+        Hop size in samples, must be positive
+
+    Returns
+    -------
+    sdr : np.ndarray, shape=(nsrc, nframes)
+        matrix of Signal to Distortion Ratios (SDR), one column per frame
+    sir : np.ndarray, shape=(nsrc, nframes)
+        matrix of Source to Interference Ratios (SIR), one column per frame
+    sar : np.ndarray, shape=(nsrc, nframes)
+        matrix of Sources to Artifacts Ratios (SAR), one column per frame
+    perm : np.ndarray, shape=(nsrc, nframes)
+        matrix whose k'th column holds the best ordering of estimated
+        sources for frame k in the mean SIR sense (estimated source number
+        perm[j, k] corresponds to true source number j)
+
+    Raises
+    ------
+    ValueError
+        If win or hop is not positive
+
+    """
+    if win <= 0 or hop <= 0:
+        raise ValueError(
+            "win and hop must be positive, got win={} and hop={}".format(
+                win, hop
+            )
+        )
+
+    # Promote single-source (1-d) input to shape (1, nsampl)
+    reference_sources = np.atleast_2d(reference_sources)
+    estimated_sources = np.atleast_2d(estimated_sources)
+
+    nsrc = reference_sources.shape[0]
+    nsampl = reference_sources.shape[1]
+
+    # Number of complete frames, clamped at zero so that a signal shorter
+    # than one window yields empty results rather than a negative size
+    nwin = max(int((nsampl - win) // hop) + 1, 0)
+
+    sdr = np.empty((nsrc, nwin))
+    sir = np.empty((nsrc, nwin))
+    sar = np.empty((nsrc, nwin))
+    # Orderings are source indices, so store them as integers
+    perm = np.empty((nsrc, nwin), dtype=int)
+
+    for k in range(nwin):
+        frame = slice(k * hop, k * hop + win)
+        sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources(
+            reference_sources[:, frame], estimated_sources[:, frame]
+        )
+
+    return sdr, sir, sar, perm
+
+
 def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
     """Decomposition of an estimated source image into four components
     representing respectively the true source image, spatial (or filtering)
@@ -359,9 +441,12 @@ def evaluate(reference_sources, estimated_sources, **kwargs):
     # Compute all the metrics
     scores = collections.OrderedDict()
 
-    sdr, sir, sar, perm = util.filter_kwargs(bss_eval_sources,
-                                             reference_sources,
-                                             estimated_sources, **kwargs)
+    sdr, sir, sar, perm = util.filter_kwargs(
+        bss_eval_sources,
+        reference_sources,
+        estimated_sources,
+        **kwargs
+    )
     scores['Source to Distortion'] = sdr.tolist()
     scores['Source to Interference'] = sir.tolist()