diff --git a/crepe/core.py b/crepe/core.py
index 9d91c6a..3fc4ebc 100644
--- a/crepe/core.py
+++ b/crepe/core.py
@@ -25,7 +25,7 @@ def build_and_load_model(model_capacity):
     """
     Build the CNN model and load the weights
-    
+
     Parameters
     ----------
     model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
@@ -92,16 +92,21 @@ def output_path(file, suffix, output_dir):
     return path
 
 
-def to_local_average_cents(salience, center=None):
+def bin_to_cents(bin):
     """
-    find the weighted average cents near the argmax bin
+    map bin numbers to cents values
     """
-
-    if not hasattr(to_local_average_cents, 'cents_mapping'):
-        # the bin number-to-cents mapping
-        to_local_average_cents.mapping = (
+    if not hasattr(bin_to_cents, 'cents_mapping'):
+        bin_to_cents.mapping = (
             np.linspace(0, 7180, 360) + 1997.3794084376191)
+    return bin_to_cents.mapping[bin]
+
+
+def to_local_average_cents(salience, center=None):
+    """
+    find the weighted average cents near the argmax bin
+    """
 
     if salience.ndim == 1:
         if center is None:
             center = int(np.argmax(salience))
@@ -109,7 +114,7 @@ def to_local_average_cents(salience, center=None):
         end = min(len(salience), center + 5)
         salience = salience[start:end]
         product_sum = np.sum(
-            salience * to_local_average_cents.mapping[start:end])
+            salience * bin_to_cents(np.arange(start, end)))
         weight_sum = np.sum(salience)
         return product_sum / weight_sum
     if salience.ndim == 2:
@@ -124,43 +129,32 @@ def to_viterbi_cents(salience):
     Find the Viterbi path using a transition prior that induces pitch
     continuity.
     """
-    from hmmlearn import hmm
-
-    # uniform prior on the starting pitch
-    starting = np.ones(360) / 360
+    from librosa.sequence import viterbi
 
     # transition probabilities inducing continuous pitch
     xx, yy = np.meshgrid(range(360), range(360))
     transition = np.maximum(12 - abs(xx - yy), 0)
     transition = transition / np.sum(transition, axis=1)[:, None]
 
-    # emission probability = fixed probability for self, evenly distribute the
-    # others
-    self_emission = 0.1
-    emission = (np.eye(360) * self_emission + np.ones(shape=(360, 360)) *
-                ((1 - self_emission) / 360))
-
-    # fix the model parameters because we are not optimizing the model
-    model = hmm.MultinomialHMM(360, starting, transition)
-    model.startprob_, model.transmat_, model.emissionprob_ = \
-        starting, transition, emission
+    # compute the posterior distribution from the logits
+    posterior = np.exp(salience) / np.sum(np.exp(salience), axis=1,
+                                          keepdims=True)
 
-    # find the Viterbi path
-    observations = np.argmax(salience, axis=1)
-    path = model.predict(observations.reshape(-1, 1), [len(observations)])
+    # determine the path through the posterior distribution
+    path = viterbi(posterior.T, transition)
 
-    return np.array([to_local_average_cents(salience[i, :], path[i]) for i in
-                     range(len(observations))])
+    # convert bin indices to cents
+    return bin_to_cents(path)
 
 
 def get_activation(audio, sr, model_capacity='full',
                    center=True, step_size=10, verbose=1):
     """
-    
+
     Parameters
     ----------
     audio : np.ndarray [shape=(N,) or (N, C)]
-        The audio samples. Multichannel audio will be downmixed. 
+        The audio samples. Multichannel audio will be downmixed.
     sr : int
         Sample rate of the audio samples. The audio will be resampled if
         the sample rate is not 16 kHz, which is expected by the model.
@@ -216,11 +210,11 @@ def predict(audio, sr, model_capacity='full', viterbi=False,
             center=True, step_size=10, verbose=1):
     """
     Perform pitch estimation on given audio
-    
+
     Parameters
     ----------
     audio : np.ndarray [shape=(N,) or (N, C)]
-        The audio samples. Multichannel audio will be downmixed. 
+        The audio samples. Multichannel audio will be downmixed.
     sr : int
         Sample rate of the audio samples. The audio will be resampled if
         the sample rate is not 16 kHz, which is expected by the model.
@@ -242,7 +236,7 @@ def predict(audio, sr, model_capacity='full',
     Returns
     -------
     A 4-tuple consisting of:
-    
+
         time: np.ndarray [shape=(T,)]
             The timestamps on which the pitch was estimated
         frequency: np.ndarray [shape=(T,)]
diff --git a/setup.py b/setup.py
index d401c8a..ea26233 100644
--- a/setup.py
+++ b/setup.py
@@ -73,8 +73,8 @@
         'matplotlib>=2.1.0',
         'resampy>=0.2.0,<0.3.0',
         'h5py>=2.7.0,<3.0.0',
-        'hmmlearn>=0.2.0,<0.3.0',
         'imageio>=2.3.0',
+        'librosa>=0.6.2',
         'scikit-learn>=0.16'
     ],
     package_data={
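
Note (illustration, not part of the patch): a minimal standalone sketch of the reworked Viterbi decoding, assuming a random placeholder salience matrix in place of the activations that get_activation() would produce. The transition matrix and softmax normalization mirror the lines added above; librosa.sequence.viterbi expects probabilities with shape (n_states, n_steps), hence the transpose.

import numpy as np
from librosa.sequence import viterbi

# placeholder activations: 100 frames x 360 pitch bins (assumed, for illustration only)
salience = np.random.rand(100, 360)

# transition prior inducing pitch continuity, as built in to_viterbi_cents()
xx, yy = np.meshgrid(range(360), range(360))
transition = np.maximum(12 - abs(xx - yy), 0)
transition = transition / np.sum(transition, axis=1)[:, None]

# per-frame posterior over the 360 bins (softmax, matching the added code)
posterior = np.exp(salience) / np.sum(np.exp(salience), axis=1, keepdims=True)

# decode the most likely bin sequence; librosa wants (n_states, n_steps)
path = viterbi(posterior.T, transition)

# map decoded bins to cents, the same mapping that bin_to_cents() caches
cents = np.linspace(0, 7180, 360)[path] + 1997.3794084376191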