From cf651169c160fd901b588af714471da39366f8e5 Mon Sep 17 00:00:00 2001 From: Anna Beranger Date: Wed, 27 Mar 2024 15:22:19 +0100 Subject: [PATCH] xgboost interpretable dimensions --- sinr/text/evaluate.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sinr/text/evaluate.py b/sinr/text/evaluate.py index 80edad7..d39d351 100644 --- a/sinr/text/evaluate.py +++ b/sinr/text/evaluate.py @@ -356,3 +356,39 @@ def clf_score(clf, X_test, y_test, scoring='accuracy', params={}): score = getattr(metrics, scoring+'_score') y_pred = clf.predict(X_test) return score(y_test, y_pred, **params) + +def clf_xgb_interpretability(sinr_vec, xgb, interpreter,topk_dim=10, topk=5, importance_type='gain'): + """Interpretability of main dimensions used by the xgboost classifier + :param sinr_vec: SINrVectors object from which datas were vectorized + :type sinr_vec: SINrVectors + :param xgb: fitted xgboost classifier + :type xgb: xgboost.XGBClassifier + :param interpreter: whether stereotypes or descriptors are requested + :type interpreter: str + :param topk_dim: Number of features requested among the main features used by the classifier (Default value = 10) + :type topk_dim: int + :param topk: `topk` value to consider on each dimension (Default value = 5) + :type topk: int + :param importance_type: ‘weight’: the number of times a feature is used to split the data across all trees, + ‘gain’: the average gain across all splits the feature is used in, + ‘cover’: the average coverage across all splits the feature is used in, + ‘total_gain’: the total gain across all splits the feature is used in + ‘total_cover’: the total coverage across all splits the feature is used in + :type importance_type: str + + :returns: Interpreters of dimensions, importance of dimensions + :rtype: list of set of object, list of tuple + + """ + + features = xgb.get_booster().get_score(importance_type=importance_type) + features = dict(sorted(features.items(), key=lambda x: x[1], reverse=True)) + features_index = [int(f[1:]) for f in list(features.keys())[:topk_dim]] + features_importance = list(features.items())[:topk_dim] + + if interpreter=='descriptors': + dim = [sinr_vec.get_dimension_descriptors_idx(index, topk=topk) for index in features_index] + elif interpreter=='stereotypes': + dim = [sinr_vec.get_dimension_stereotypes_idx(index, topk=topk) for index in features_index] + + return dim, features_importance