Skip to content

Commit

Permalink
xgboost interpretable dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
aberanger committed Mar 27, 2024
1 parent 41fef29 commit cf65116
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions sinr/text/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,39 @@ def clf_score(clf, X_test, y_test, scoring='accuracy', params={}):
score = getattr(metrics, scoring+'_score')
y_pred = clf.predict(X_test)
return score(y_test, y_pred, **params)

def clf_xgb_interpretability(sinr_vec, xgb, interpreter, topk_dim=10, topk=5, importance_type='gain'):
    """Interpret the most important dimensions used by a fitted xgboost classifier.

    The booster's per-feature importance scores are sorted in decreasing order,
    the `topk_dim` best features are kept, and each of them is interpreted
    through `sinr_vec`.

    :param sinr_vec: SINrVectors object from which the data were vectorized
    :type sinr_vec: SINrVectors
    :param xgb: fitted xgboost classifier
    :type xgb: xgboost.XGBClassifier
    :param interpreter: whether `'stereotypes'` or `'descriptors'` are requested
    :type interpreter: str
    :param topk_dim: Number of features requested among the main features used by the classifier (Default value = 10)
    :type topk_dim: int
    :param topk: `topk` value to consider on each dimension (Default value = 5)
    :type topk: int
    :param importance_type: 'weight': the number of times a feature is used to split the data across all trees,
        'gain': the average gain across all splits the feature is used in,
        'cover': the average coverage across all splits the feature is used in,
        'total_gain': the total gain across all splits the feature is used in,
        'total_cover': the total coverage across all splits the feature is used in
        (Default value = 'gain')
    :type importance_type: str
    :returns: Interpreters of dimensions (one entry per selected feature, in
        decreasing order of importance), and the matching
        `(feature name, importance)` pairs
    :rtype: list of set of object, list of tuple
    :raises ValueError: if `interpreter` is neither `'descriptors'` nor `'stereotypes'`
    """
    # Validate first: an unknown value used to surface only at the end as an
    # UnboundLocalError on `dim`, after the booster had already been queried.
    if interpreter == 'descriptors':
        interpret_dimension = sinr_vec.get_dimension_descriptors_idx
    elif interpreter == 'stereotypes':
        interpret_dimension = sinr_vec.get_dimension_stereotypes_idx
    else:
        raise ValueError(
            "interpreter must be 'descriptors' or 'stereotypes', got %r" % (interpreter,)
        )

    scores = xgb.get_booster().get_score(importance_type=importance_type)
    # Highest importance first; sorted() is stable, so ties keep booster order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    features_importance = ranked[:topk_dim]

    # Assumes xgboost's default feature names ("f0", "f1", ...): dropping the
    # leading character yields the dimension index. A model fitted with custom
    # feature names would make int() raise here -- TODO confirm with callers.
    features_index = [int(name[1:]) for name, _ in features_importance]

    dim = [interpret_dimension(index, topk=topk) for index in features_index]

    return dim, features_importance

0 comments on commit cf65116

Please sign in to comment.