diff --git a/machine_learning/loss_functions.py b/machine_learning/loss_functions.py index 0bd9aa8b5401..8bc16e1c1c7a 100644 --- a/machine_learning/loss_functions.py +++ b/machine_learning/loss_functions.py @@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float - y_true: True class probabilities - y_pred: Predicted class probabilities + >>> true_labels = np.array([0, 0.4, 0.6]) + >>> predicted_probs = np.array([0.3, 0.3, 0.4]) + >>> float(kullback_leibler_divergence(true_labels, predicted_probs)) + 0.35835189384561095 + >>> true_labels = np.array([0.2, 0.3, 0.5]) >>> predicted_probs = np.array([0.3, 0.3, 0.4]) >>> float(kullback_leibler_divergence(true_labels, predicted_probs)) @@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float if len(y_true) != len(y_pred): raise ValueError("Input arrays must have the same length.") + filter_array = y_true != 0 + y_true = y_true[filter_array] + y_pred = y_pred[filter_array] kl_loss = y_true * np.log(y_true / y_pred) return np.sum(kl_loss)