Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Decoding from the categorical posterior distribution #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 26 additions & 32 deletions crepe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def build_and_load_model(model_capacity):
"""
Build the CNN model and load the weights

Parameters
----------
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
Expand Down Expand Up @@ -92,24 +92,29 @@ def output_path(file, suffix, output_dir):
return path


def to_local_average_cents(salience, center=None):
def bin_to_cents(bin):
"""
find the weighted average cents near the argmax bin
map bin numbers to cents values
"""

if not hasattr(to_local_average_cents, 'cents_mapping'):
# the bin number-to-cents mapping
to_local_average_cents.mapping = (
if not hasattr(bin_to_cents, 'cents_mapping'):
bin_to_cents.mapping = (
np.linspace(0, 7180, 360) + 1997.3794084376191)

return bin_to_cents.mapping[bin]


def to_local_average_cents(salience, center=None):
"""
find the weighted average cents near the argmax bin
"""
if salience.ndim == 1:
if center is None:
center = int(np.argmax(salience))
start = max(0, center - 4)
end = min(len(salience), center + 5)
salience = salience[start:end]
product_sum = np.sum(
salience * to_local_average_cents.mapping[start:end])
salience * bin_to_cents(np.arange(start, end)))
weight_sum = np.sum(salience)
return product_sum / weight_sum
if salience.ndim == 2:
Expand All @@ -124,43 +129,32 @@ def to_viterbi_cents(salience):
Find the Viterbi path using a transition prior that induces pitch
continuity.
"""
from hmmlearn import hmm

# uniform prior on the starting pitch
starting = np.ones(360) / 360
from librosa.sequence import viterbi

# transition probabilities inducing continuous pitch
xx, yy = np.meshgrid(range(360), range(360))
transition = np.maximum(12 - abs(xx - yy), 0)
transition = transition / np.sum(transition, axis=1)[:, None]

# emission probability = fixed probability for self, evenly distribute the
# others
self_emission = 0.1
emission = (np.eye(360) * self_emission + np.ones(shape=(360, 360)) *
((1 - self_emission) / 360))

# fix the model parameters because we are not optimizing the model
model = hmm.MultinomialHMM(360, starting, transition)
model.startprob_, model.transmat_, model.emissionprob_ = \
starting, transition, emission
# compute the posterior distribution from the logits
posterior = np.exp(salience) / np.sum(np.exp(salience), axis=1,
keepdims=True)

# find the Viterbi path
observations = np.argmax(salience, axis=1)
path = model.predict(observations.reshape(-1, 1), [len(observations)])
# determine the path through the posterior distribution
path = viterbi(posterior.T, transition)

return np.array([to_local_average_cents(salience[i, :], path[i]) for i in
range(len(observations))])
# convert bin indices to cents
return bin_to_cents(path)


def get_activation(audio, sr, model_capacity='full', center=True, step_size=10,
verbose=1):
"""

Parameters
----------
audio : np.ndarray [shape=(N,) or (N, C)]
The audio samples. Multichannel audio will be downmixed.
The audio samples. Multichannel audio will be downmixed.
sr : int
Sample rate of the audio samples. The audio will be resampled if
the sample rate is not 16 kHz, which is expected by the model.
Expand Down Expand Up @@ -216,11 +210,11 @@ def predict(audio, sr, model_capacity='full',
viterbi=False, center=True, step_size=10, verbose=1):
"""
Perform pitch estimation on given audio

Parameters
----------
audio : np.ndarray [shape=(N,) or (N, C)]
The audio samples. Multichannel audio will be downmixed.
The audio samples. Multichannel audio will be downmixed.
sr : int
Sample rate of the audio samples. The audio will be resampled if
the sample rate is not 16 kHz, which is expected by the model.
Expand All @@ -242,7 +236,7 @@ def predict(audio, sr, model_capacity='full',
Returns
-------
A 4-tuple consisting of:

time: np.ndarray [shape=(T,)]
The timestamps on which the pitch was estimated
frequency: np.ndarray [shape=(T,)]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@
'matplotlib>=2.1.0',
'resampy>=0.2.0,<0.3.0',
'h5py>=2.7.0,<3.0.0',
'hmmlearn>=0.2.0,<0.3.0',
'imageio>=2.3.0',
'librosa>=0.6.2',
'scikit-learn>=0.16'
],
package_data={
Expand Down