From 892dd517eb17698e4d1aac732353767330f1bfd4 Mon Sep 17 00:00:00 2001 From: KikuchiTomo <48982211+KikuchiTomo@users.noreply.github.com> Date: Thu, 27 Jul 2023 17:06:32 +0900 Subject: [PATCH] fix accuracy calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Accuracy = (Σ_i w[i, i])/(Σ_i Σ_j w[i, j]) --- metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metrics.py b/metrics.py index 3ae3a35..055cd9d 100644 --- a/metrics.py +++ b/metrics.py @@ -22,6 +22,6 @@ def acc(y_true, y_pred): w = np.zeros((D, D), dtype=np.int64) for i in range(y_pred.size): w[y_pred[i], y_true[i]] += 1 - from sklearn.utils.linear_assignment_ import linear_assignment - ind = linear_assignment(w.max() - w) - return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size \ No newline at end of file + from scipy.optimize import linear_sum_assignment + ind = linear_sum_assignment(np.amax(w) - w) + return sum([w[i, i] for i, j in ind]) * 1.0 / y_pred.size