Skip to content

Commit

Permalink
Add framewise evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
Faro authored and craffel committed Jul 11, 2016
1 parent 5887679 commit 67700d1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
1 change: 0 additions & 1 deletion evaluators/separation_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import sys
import os
import glob
import os
import numpy as np
import eval_utilities

Expand Down
70 changes: 67 additions & 3 deletions mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,67 @@ def bss_eval_sources(reference_sources, estimated_sources):
return (sdr[idx], sir[idx], sar[idx], popt)


def bss_eval_sources_framewise(
reference_sources, estimated_sources, win, hop
):
"""Framewise computation of bss_eval_sources
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated
>>> # source
>>> (sdr, sir, sar,
... perm) = mir_eval.separation.bss_eval_sources_framewise(
reference_sources,
... estimated_sources)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing true sources
estimated_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing estimated sources
win : int
Window length
hop : int
Hop size
Returns
-------
sdr : np.ndarray, shape=(nsrc, nframes)
vector of Signal to Distortion Ratios (SDR)
sir : np.ndarray, shape=(nsrc, nframes)
vector of Source to Interference Ratios (SIR)
sar : np.ndarray, shape=(nsrc, nframes)
vector of Sources to Artifacts Ratios (SAR)
perm : np.ndarray, shape=(nsrc, nframes)
vector containing the best ordering of estimated sources in
the mean SIR sense (estimated source number perm[j] corresponds to
true source number j)
"""
nsrc = reference_sources.shape[0]

nwin = int(
np.floor((reference_sources.shape[1] - win + hop) / hop)
)

SDR = np.empty((nsrc, nwin))
SIR = np.empty((nsrc, nwin))
SAR = np.empty((nsrc, nwin))
perm = np.empty((nsrc, nwin))

for k in range(nwin):
K = slice(k * hop, k * hop + win)
SDR[:, k], SIR[:, k], SAR[:, k], perm[:, k] = bss_eval_sources(
reference_sources[:, K], estimated_sources[:, K]
)

return SDR, SIR, SAR, perm


def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
"""Decomposition of an estimated source image into four components
representing respectively the true source image, spatial (or filtering)
Expand Down Expand Up @@ -359,9 +420,12 @@ def evaluate(reference_sources, estimated_sources, **kwargs):
# Compute all the metrics
scores = collections.OrderedDict()

sdr, sir, sar, perm = util.filter_kwargs(bss_eval_sources,
reference_sources,
estimated_sources, **kwargs)
sdr, sir, sar, perm = util.filter_kwargs(
bss_eval_sources,
reference_sources,
estimated_sources,
**kwargs
)

scores['Source to Distortion'] = sdr.tolist()
scores['Source to Interference'] = sir.tolist()
Expand Down

0 comments on commit 67700d1

Please sign in to comment.