diff --git a/metrics.py b/metrics.py index 3ae3a35..055cd9d 100644 --- a/metrics.py +++ b/metrics.py @@ -22,6 +22,6 @@ def acc(y_true, y_pred): w = np.zeros((D, D), dtype=np.int64) for i in range(y_pred.size): w[y_pred[i], y_true[i]] += 1 - from sklearn.utils.linear_assignment_ import linear_assignment - ind = linear_assignment(w.max() - w) - return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size \ No newline at end of file + from scipy.optimize import linear_sum_assignment + ind = linear_sum_assignment(np.amax(w) - w) + return sum([w[i, i] for i, j in ind]) * 1.0 / y_pred.size