-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explainer api for local classifiers #102
Changes from all commits
62c218d
ea1fff8
c4d75c5
299af62
0ea8956
1efd946
7dcb52f
1de360c
933b1f6
c57abed
a829ce0
33f2cbc
c06d8a7
c597fce
b79e5f4
8a643f1
606c1eb
ca6c654
9936dc3
82573be
d53e8d9
2449928
759489f
0771c08
4eb6f5c
3955521
7c2f4d2
b12bdc3
986b61c
eb11c0e
8c700e4
53a90a0
5e74762
b1f3656
2a12087
84f6e39
b09f8da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
"""Explainer API for explaining predictions using shapley values.""" | ||
|
||
from copy import deepcopy | ||
|
||
import numpy as np | ||
|
||
from hiclass import ( | ||
LocalClassifierPerParentNode, | ||
LocalClassifierPerNode, | ||
LocalClassifierPerLevel, | ||
ConstantClassifier, | ||
) | ||
|
||
try: | ||
import xarray as xr | ||
except ImportError: | ||
xarray_installed = False | ||
else: | ||
xarray_installed = True | ||
|
||
try: | ||
import shap | ||
except ImportError: | ||
shap_installed = False | ||
else: | ||
shap_installed = True | ||
|
||
|
||
def _check_imports():
    """Raise ``ImportError`` when an optional dependency is missing.

    ``shap`` is checked before ``xarray``, so when both are missing only the
    shap message is reported.

    Raises
    ------
    ImportError
        If the module-level flag ``shap_installed`` or ``xarray_installed``
        is false.
    """
    # Fast path: both optional packages were importable at module load.
    if shap_installed and xarray_installed:
        return

    if not shap_installed:
        raise ImportError(
            "Shap is not installed. Please install it using `pip install shap` first."
        )

    # shap is available here, so xarray must be the missing one.
    raise ImportError(
        "xarray is not installed. Please install it using `pip install xarray` first."
    )
|
||
|
||
class Explainer:
    """Explainer class for returning shap values for each of the three hierarchical classifiers.

    The public entry point is :meth:`explain`, which dispatches on the type of
    ``hierarchical_model`` (per parent node, per level or per node).
    """

    # Maps the ``mode`` argument to the name of the shap explainer class to use.
    # Any other value falls back to the generic ``shap.Explainer``.
    _EXPLAINER_NAME_BY_MODE = {
        "tree": "TreeExplainer",
        "gradient": "GradientExplainer",
        "deep": "DeepExplainer",
        "linear": "LinearExplainer",
    }

    def __init__(self, hierarchical_model, data=None, algorithm="auto", mode=""):
        """
        Initialize the SHAP explainer for a hierarchical model.

        Parameters
        ----------
        hierarchical_model : HierarchicalClassifier
            The hierarchical classification model to explain.
        data : array-like or None, default=None
            The dataset used for creating the SHAP explainer. It is passed as
            the second positional argument to the shap explainer class.
        algorithm : str, default="auto"
            The algorithm to use for SHAP explainer.
            NOTE(review): this value is stored on the instance but is not
            forwarded to any shap explainer in this class.
        mode : str, default=""
            The mode of the SHAP explainer. Can be 'tree', 'gradient', 'deep',
            'linear', or '' for default SHAP explainer.

        Raises
        ------
        ImportError
            If shap or xarray is not installed.
        """
        self.hierarchical_model = hierarchical_model
        self.algorithm = algorithm
        self.mode = mode
        self.data = data

        _check_imports()

        explainer_name = "Explainer"
        if isinstance(mode, str):
            explainer_name = self._EXPLAINER_NAME_BY_MODE.get(mode, "Explainer")
        # Only the selected attribute is looked up on the shap module.
        self.explainer = getattr(shap, explainer_name)

    def explain(self, X):
        """
        Generate SHAP values for the given samples.

        Parameters
        ----------
        X : array-like
            Samples for which SHAP values are computed.

        Returns
        -------
        explanations : list or dict
            For a ``LocalClassifierPerParentNode`` model, a list with one
            entry per sample (see ``_explain_with_xr``). For
            ``LocalClassifierPerLevel`` a dict keyed by level, and for
            ``LocalClassifierPerNode`` a dict keyed by node.

        Raises
        ------
        ImportError
            If shap or xarray is not installed.
        ValueError
            If the model is none of the three supported classifier types.
        """
        _check_imports()
        model = self.hierarchical_model
        if isinstance(model, LocalClassifierPerParentNode):
            return self._explain_with_xr(X)
        if isinstance(model, LocalClassifierPerLevel):
            return self._explain_lcpl(X)
        if isinstance(model, LocalClassifierPerNode):
            return self._explain_lcpn(X)
        raise ValueError(f"Invalid model: {self.hierarchical_model}.")

    def _build_local_explainer(self, local_classifier):
        """Instantiate ``self.explainer`` for one local classifier, passing ``self.data``."""
        # ``self.explainer`` is a class; ``copy.deepcopy`` returns classes
        # unchanged, so it is called directly instead of being copied first.
        return self.explainer(local_classifier, self.data)

    @staticmethod
    def _as_3d(shap_values):
        """Return ``shap_values`` with a leading axis of length 1 added when it has fewer than 3 dims."""
        if shap_values.ndim < 3:
            shap_values = shap_values.reshape(
                1, shap_values.shape[0], shap_values.shape[1]
            )
        return shap_values

    def _explain_with_dict(self, X):
        """
        Generate SHAP values for each node using Local Classifier Per Parent Node (LCPPN) strategy.

        SHAP values are computed for every node that lies on the predicted
        path of at least one sample in ``X``. Every other node that carries
        node attributes gets an all-NaN array instead.

        Parameters
        ----------
        X : array-like
            Sample data for which to generate SHAP values. Must expose
            ``shape`` with two entries (it is used to size the NaN arrays).

        Returns
        -------
        shap_values_dict : dict
            A dictionary of SHAP values for each node.
        """
        hierarchy = self.hierarchical_model.hierarchy_

        # _get_traversed_nodes returns one list of nodes per sample; flatten it
        # to the unique nodes (first-seen order) so each node is explained once.
        traversals = self._get_traversed_nodes(X)
        traversed_nodes = list(
            dict.fromkeys(node for path in traversals for node in path)
        )

        shap_values_dict = {}
        for node in traversed_nodes:
            local_classifier = hierarchy.nodes[node]["classifier"]

            # Create explainer with train data
            local_explainer = self._build_local_explainer(local_classifier)
            shap_values = np.array(local_explainer.shap_values(X))
            shap_values_dict[node] = self._as_3d(shap_values)

        visited = set(traversed_nodes)
        for node in hierarchy.nodes:
            if node in visited:
                continue
            node_attributes = hierarchy.nodes[node]
            # Nodes without any attributes have no classifier and are skipped.
            if len(node_attributes) != 0:
                shap_values_dict[node] = np.full(
                    (
                        len(node_attributes["classifier"].classes_),
                        X.shape[0],
                        X.shape[1],
                    ),
                    np.nan,
                )
        return shap_values_dict

    def _explain_with_xr(self, X):
        """
        Generate SHAP values for each node using Local Classifier Per Parent Node (LCPPN) strategy.

        Parameters
        ----------
        X : array-like
            Sample data for which to generate SHAP values. Each row must
            support ``reshape``.

        Returns
        -------
        explanations : list
            One entry per sample; each entry is the list of datasets returned
            by ``_calculate_shap_values`` for that sample.
        """
        return [self._calculate_shap_values(sample.reshape(1, -1)) for sample in X]

    def _explain_lcpn(self, X):
        """
        Generate SHAP values for each node using Local Classifier Per Node (LCPN) strategy.

        The root node and nodes whose classifier is a ``ConstantClassifier``
        are skipped.

        Parameters
        ----------
        X : array-like
            Sample data for which to generate SHAP values.

        Returns
        -------
        shap_values_dict : dict
            A dictionary of SHAP values for each explained node.
        """
        hierarchy = self.hierarchical_model.hierarchy_
        shap_values_dict = {}
        for node in hierarchy.nodes:
            if node == self.hierarchical_model.root_:
                continue

            local_classifier = hierarchy.nodes[node]["classifier"]
            if isinstance(local_classifier, ConstantClassifier.ConstantClassifier):
                continue

            local_explainer = self._build_local_explainer(local_classifier)

            # Calculate SHAP values for the given sample X
            shap_values_dict[node] = np.array(local_explainer.shap_values(X))

        return shap_values_dict

    def _explain_lcpl(self, X):
        """
        Generate SHAP values for each level using Local Classifier Per Level (LCPL) strategy.

        Parameters
        ----------
        X : array-like
            Sample data for which to generate SHAP values.

        Returns
        -------
        shap_values_dict : dict
            A dictionary mapping each explained level index to its SHAP values.
        """
        local_classifiers = self.hierarchical_model.local_classifiers_

        # NOTE(review): level 0 is never explained, and level 1 is skipped too
        # when ``len()`` of its entry is 1 - confirm the intent of this check.
        start_level = 1
        if len(local_classifiers[start_level]) == 1:
            start_level = 2

        shap_values_dict = {}
        for level in range(start_level, len(local_classifiers)):
            local_explainer = self._build_local_explainer(local_classifiers[level])

            # Calculate SHAP values for the given sample X
            shap_values_dict[level] = np.array(local_explainer.shap_values(X))

        return shap_values_dict

    def _get_traversed_nodes(self, samples):
        """
        Return, for each sample, the nodes visited while predicting it.

        Parameters
        ----------
        samples : array-like
            Samples to route through the hierarchy. Each row must support
            ``reshape``.

        Returns
        -------
        traversals : list of list, or None
            For a ``LocalClassifierPerParentNode`` model, one list per sample
            holding the visited nodes that own a classifier, starting at the
            root. ``None`` for the other model types (not implemented).
        """
        model = self.hierarchical_model
        if not isinstance(model, LocalClassifierPerParentNode):
            # Traversal is not implemented for LCPN / LCPL models.
            return None

        hierarchy = model.hierarchy_
        traversals = []
        for x in samples:
            current = model.root_
            traversal_order = []
            # NOTE(review): if hierarchy_ is a networkx DiGraph, neighbors()
            # returns an iterator, which is always truthy, so this loop ends
            # only through the break below - confirm.
            while hierarchy.neighbors(current):
                if "classifier" not in hierarchy.nodes[current]:
                    break  # Break if reached leaf node
                traversal_order.append(current)
                # Follow the child predicted by this node's local classifier.
                current = hierarchy.nodes[current]["classifier"].predict(
                    x.reshape(1, -1)
                )[0]
            traversals.append(traversal_order)
        return traversals

    def _calculate_shap_values(self, X):
        """
        Compute SHAP values along the predicted path of a single sample.

        Parameters
        ----------
        X : array-like
            The sample to explain. Only the first traversal returned by
            ``_get_traversed_nodes`` is used.

        Returns
        -------
        datasets : list of xarray.Dataset
            One dataset per traversed node with the variables ``node``,
            ``predicted_class``, ``predict_proba``, ``classes`` and
            ``shap_values``.

        Raises
        ------
        ImportError
            If xarray is not installed.
        """
        if not xarray_installed:
            raise ImportError(
                "xarray is not installed. Please install it using `pip install xarray` before using "
                "this method."
            )
        traversed_nodes = self._get_traversed_nodes(X)[0]
        datasets = []
        for node in traversed_nodes:
            local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
                "classifier"
            ]

            # Create a SHAP explainer for the local classifier
            local_explainer = self._build_local_explainer(local_classifier)

            # Calculate SHAP values for the given sample X
            shap_values = self._as_3d(
                np.array(local_explainer.shap_values(X, check_additivity=False))
            )

            class_labels = local_classifier.classes_
            predict_proba = xr.DataArray(
                local_classifier.predict_proba(X)[0],
                dims=["class"],
                coords={
                    "class": class_labels,
                },
            )
            classes = xr.DataArray(
                class_labels,
                dims=["class"],
                coords={"class": class_labels},
            )

            shap_val_local = xr.DataArray(
                shap_values,
                dims=["class", "sample", "feature"],
            )

            local_dataset = xr.Dataset(
                {
                    # Only the last path component is kept as the node name.
                    "node": node.split(self.hierarchical_model.separator_)[-1],
                    "predicted_class": local_classifier.predict(X),
                    "predict_proba": predict_proba,
                    "classes": classes,
                    "shap_values": shap_val_local,
                }
            )
            datasets.append(local_dataset)
        return datasets
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -224,3 +224,44 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): | |
self.logger_.info("Fitting local classifiers") | ||
nodes = self._get_parents() | ||
self._fit_node_classifier(nodes, local_mode, use_joblib) | ||
|
||
def get_predict_proba(self, X):
    """
    Prediction probabilities along the predicted path of each sample.

    For every sample the predicted labels are joined with ``self.separator_``
    into node names (one per level except the last), and the
    ``predict_proba`` output of each such node's local classifier is stored.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input samples. Rows are iterated one by one and each row must
        support ``reshape``.

    Returns
    -------
    prediction_probabilities : dict or list of dicts if X contains multiple samples.
        Each dict contains the node as its key and the ``predict_proba``
        output of that node's local classifier as value. An empty ``X``
        yields an empty list.
    """
    prediction_probabilities = []
    for sample in X:
        row = sample.reshape(1, -1)
        y_pred = self.predict(row)

        # Coerce every label to str before joining; concatenating a
        # non-string label with the separator would raise TypeError.
        traversal_path = self.separator_.join(str(label) for label in y_pred[0])
        path_parts = traversal_path.split(self.separator_)

        predict_proba_dict = {}
        # The last level is excluded from the lookup.
        # NOTE(review): presumably because leaf nodes have no local
        # classifier - confirm.
        for depth in range(self.max_levels_ - 1):
            node = self.separator_.join(path_parts[: depth + 1])
            local_classifier = self.hierarchy_.nodes[node]["classifier"]
            predict_proba_dict[node] = local_classifier.predict_proba(row)
        prediction_probabilities.append(predict_proba_dict)

    # A single sample is returned unwrapped, as documented above.
    if len(prediction_probabilities) == 1:
        return prediction_probabilities[0]

    return prediction_probabilities
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cool to have as a dictionary but not 100% useful if we do not know which was the predicted and/or the correct path the model has chosen. Would it be possible to return the predicted label and the corresponding shap values?