Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainer api for local classifiers #102

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
62c218d
added initial implementation of explainer api for lcppn
Jan 9, 2024
ea1fff8
fixed lints
Jan 10, 2024
c4d75c5
fixed lints
Jan 14, 2024
299af62
added an _explain_lcppn implementation and some tests provided
iwan-tee Jan 14, 2024
0ea8956
modified docstrings
Jan 14, 2024
1efd946
explainer for lcpn implemented + tests added and some cases fixed
iwan-tee Jan 14, 2024
7dcb52f
Merge branch 'explainer_api_lcpn' into explainer_api
iwan-tee Jan 14, 2024
1de360c
tests added + some bugs fixed
iwan-tee Jan 15, 2024
933b1f6
base
iwan-tee Jan 15, 2024
c57abed
basic implementation
iwan-tee Jan 18, 2024
a829ce0
LCPL explanator implementation + test
iwan-tee Jan 23, 2024
33f2cbc
added tests for hierarchy without roots
Jan 26, 2024
c06d8a7
check on root node added
iwan-tee Jan 26, 2024
c597fce
minor updates
Jan 26, 2024
b79e5f4
codestyling
iwan-tee Jan 26, 2024
8a643f1
codestyling
iwan-tee Jan 26, 2024
606c1eb
Merge branch 'explainer_master' into explainer_api_lcpl
ashishpatel16 Jan 26, 2024
ca6c654
Update Explainer.py
ashishpatel16 Jan 26, 2024
9936dc3
Merge pull request #1 from ashishpatel16/explainer_api_lcpl
ashishpatel16 Jan 26, 2024
82573be
Merge pull request #2 from ashishpatel16/explainer_api_lcpn
ashishpatel16 Jan 26, 2024
d53e8d9
added support for xarray for lcppn
Jan 29, 2024
2449928
Merge branch 'explainer_master' into explainer_api
ashishpatel16 Jan 29, 2024
759489f
Update Explainer.py
ashishpatel16 Jan 29, 2024
0771c08
Update Explainer.py
ashishpatel16 Jan 29, 2024
4eb6f5c
fixed errors with classifier with single class
Jan 30, 2024
3955521
updated test cases and removed cached explainers
Feb 1, 2024
7c2f4d2
removed cached explainers
Feb 1, 2024
b12bdc3
modified predict proba to return dict
Feb 1, 2024
986b61c
Merge branch 'main' into explainer_api
ashishpatel16 Feb 2, 2024
eb11c0e
updated get_predict_proba to return only traversed prediction probabi…
Feb 3, 2024
8c700e4
updated fork
Feb 3, 2024
53a90a0
separate test file for explainer
Feb 3, 2024
5e74762
Update Explainer.py
ashishpatel16 Feb 5, 2024
b1f3656
_get_traversed_nodes edited
iwan-tee Feb 6, 2024
2a12087
fixed lints
Feb 12, 2024
84f6e39
fixed conflicts
Feb 12, 2024
b09f8da
refactored and cleaned up code
Feb 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ name = "pypi"
networkx = "*"
numpy = "*"
scikit-learn = "*"
shap = "*"
xarray = "*"

[dev-packages]
pytest = "*"
Expand Down
1,940 changes: 1,124 additions & 816 deletions Pipfile.lock

Large diffs are not rendered by default.

304 changes: 304 additions & 0 deletions hiclass/Explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
"""Explainer API for explaining predictions using shapley values."""

from copy import deepcopy

import numpy as np

from hiclass import (
LocalClassifierPerParentNode,
LocalClassifierPerNode,
LocalClassifierPerLevel,
ConstantClassifier,
)

try:
import xarray as xr
except ImportError:
xarray_installed = False
else:
xarray_installed = True

try:
import shap
except ImportError:
shap_installed = False
else:
shap_installed = True


def _check_imports():
if not shap_installed:
raise ImportError(
"Shap is not installed. Please install it using `pip install shap` first."
)
elif not xarray_installed:
raise ImportError(
"xarray is not installed. Please install it using `pip install xarray` first."
)


class Explainer:
"""Explainer class for returning shap values for each of the three hierarchical classifiers."""

def __init__(self, hierarchical_model, data=None, algorithm="auto", mode=""):
"""
Initialize the SHAP explainer for a hierarchical model.

Parameters
----------
hierarchical_model : HierarchicalClassifier
The hierarchical classification model to explain.
data : array-like or None, default=None
The dataset used for creating the SHAP explainer.
algorithm : str, default="auto"
The algorithm to use for SHAP explainer.
mode : str, default=""
The mode of the SHAP explainer. Can be 'tree', 'gradient', 'deep', 'linear', or '' for default SHAP explainer.
"""
self.hierarchical_model = hierarchical_model
self.algorithm = algorithm
self.mode = mode
self.data = data

_check_imports()

if mode == "tree":
self.explainer = shap.TreeExplainer
elif mode == "gradient":
self.explainer = shap.GradientExplainer
elif mode == "deep":
self.explainer = shap.DeepExplainer
elif mode == "linear":
self.explainer = shap.LinearExplainer
else:
self.explainer = shap.Explainer

def explain(self, X):
"""
Generate SHAP values for each node in the hierarchy for the given data.

Parameters
----------
X : array-like
Training data to fit the SHAP explainer.

Returns
-------
shap_values_dict : dict
A dictionary of SHAP values for each node.
"""
_check_imports()
if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
return self._explain_with_xr(X)
elif isinstance(self.hierarchical_model, LocalClassifierPerLevel):
return self._explain_lcpl(X)
elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
return self._explain_lcpn(X)
else:
raise ValueError(f"Invalid model: {self.hierarchical_model}.")

def _explain_with_dict(self, X):
"""
Generate SHAP values for each node using Local Classifier Per Parent Node (LCPPN) strategy.

Parameters
----------
X : array-like
Sample data for which to generate SHAP values.

traverse_prediction : True or False
If True, restricts calculation of shap values to only traversed hierarchy as predicted by hiclass model.

Returns
-------
shap_values_dict : dict
A dictionary of SHAP values for each node.
"""
shap_values_dict = {}
traversed_nodes = self._get_traversed_nodes(X)
for node in traversed_nodes:
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
"classifier"
]

# Create explainer with train data
local_explainer = deepcopy(self.explainer)(local_classifier, self.data)
shap_values = np.array(local_explainer.shap_values(X))

if len(shap_values.shape) < 3:
shap_values = shap_values.reshape(
1, shap_values.shape[0], shap_values.shape[1]
)

shap_values_dict[node] = shap_values

for node in self.hierarchical_model.hierarchy_.nodes:
if node not in traversed_nodes:
local_classifier = self.hierarchical_model.hierarchy_.nodes[node]
if len(local_classifier) != 0:
shap_val = np.full(
(
len(local_classifier["classifier"].classes_),
X.shape[0],
X.shape[1],
),
np.nan,
)
shap_values_dict[node] = shap_val
return shap_values_dict

def _explain_with_xr(self, X):
"""
Generate SHAP values for each node using Local Classifier Per Parent Node (LCPPN) strategy.

Parameters
----------
X : array-like
Sample data for which to generate SHAP values.

Returns
-------
shap_values_dict : dict
A dictionary of SHAP values for each node.
"""
explanations = []
for sample in X:
explanation = self._calculate_shap_values(sample.reshape(1, -1))
explanations.append(explanation)
return explanations

def _explain_lcpn(self, X):
shap_values_dict = {}
for node in self.hierarchical_model.hierarchy_.nodes:
if node == self.hierarchical_model.root_:
continue

if isinstance(
self.hierarchical_model.hierarchy_.nodes[node]["classifier"],
ConstantClassifier.ConstantClassifier,
):
continue

local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
"classifier"
]

local_explainer = deepcopy(self.explainer)(local_classifier, self.data)

# Calculate SHAP values for the given sample X
shap_values = np.array(local_explainer.shap_values(X))
shap_values_dict[node] = shap_values

return shap_values_dict

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool to have as a dictionary but not 100% useful if we do not know which was the predicted and/or the correct path the model has chosen. Would it be possible to return the predicted label and the corresponding shap values?


def _explain_lcpl(self, X):
"""
Generate SHAP values for each node using Local Classifier Per Level (LCPL) strategy.

Parameters
----------
X : array-like
Sample data for which to generate SHAP values.

Returns
-------
shap_values_dict : dict
A dictionary of SHAP values for each node.
"""
shap_values_dict = {}
start_level = 1
if len(self.hierarchical_model.local_classifiers_[start_level]) == 1:
start_level = 2

for level in range(
start_level, len(self.hierarchical_model.local_classifiers_)
):
local_classifier = self.hierarchical_model.local_classifiers_[level]
local_explainer = deepcopy(self.explainer)(local_classifier, self.data)

# Calculate SHAP values for the given sample X
shap_values = np.array(local_explainer.shap_values(X))
shap_values_dict[level] = shap_values

return shap_values_dict

def _get_traversed_nodes(self, samples):
# Helper function to return traversed nodes
if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
traversals = []
start_node = self.hierarchical_model.root_
for x in samples:
current = start_node
traversal_order = []
while self.hierarchical_model.hierarchy_.neighbors(current):
if (
"classifier"
not in self.hierarchical_model.hierarchy_.nodes[current]
):
break # Break if reached leaf node
traversal_order.append(current)
successor = self.hierarchical_model.hierarchy_.nodes[current][
"classifier"
].predict(x.reshape(1, -1))[0]
current = successor
traversals.append(traversal_order)
return traversals
elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
pass
elif isinstance(self.hierarchical_model, LocalClassifierPerLevel):
pass

def _calculate_shap_values(self, X):
if not xarray_installed:
raise ImportError(
"xarray is not installed. Please install it using `pip install xarray` before using "
"this method."
)
traversed_nodes = self._get_traversed_nodes(X)[0]
datasets = []
for node in traversed_nodes:
local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
"classifier"
]

# Create a SHAP explainer for the local classifier
local_explainer = deepcopy(self.explainer)(local_classifier, self.data)

# Calculate SHAP values for the given sample X
shap_values = np.array(
local_explainer.shap_values(X, check_additivity=False)
)
if len(shap_values.shape) < 3:
shap_values = shap_values.reshape(
1, shap_values.shape[0], shap_values.shape[1]
)

predict_proba = xr.DataArray(
local_classifier.predict_proba(X)[0],
dims=["class"],
coords={
"class": local_classifier.classes_,
},
)
classes = xr.DataArray(
local_classifier.classes_,
dims=["class"],
coords={"class": local_classifier.classes_},
)

shap_val_local = xr.DataArray(
shap_values,
dims=["class", "sample", "feature"],
)

local_dataset = xr.Dataset(
{
"node": node.split(self.hierarchical_model.separator_)[-1],
"predicted_class": local_classifier.predict(X),
"predict_proba": predict_proba,
"classes": classes,
"shap_values": shap_val_local,
}
)
datasets.append(local_dataset)
return datasets
41 changes: 41 additions & 0 deletions hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,44 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
self.logger_.info("Fitting local classifiers")
nodes = self._get_parents()
self._fit_node_classifier(nodes, local_mode, use_joblib)

def get_predict_proba(self, X):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jannikgro I have updated the function to return a dict, however it currently just returns predict_proba for all nodes in the hierarchy. I'm woking on another version which only gives the dict for only traversed nodes.

"""
Prediction probabilities for each class for the given data.

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The input samples. Internally, its dtype will be converted
to ``dtype=np.float32``. If a sparse matrix is provided, it will be
converted into a sparse ``csr_matrix``.
Returns
-------
prediction_probabilities : dict or list of dicts if X contains multiple samples.
Each dict contains the node as its key and prediction probabilities for
the predicted classes as value.
"""
prediction_probabilities = []
for sample in X:
predict_proba_dict = {}
y_pred = self.predict(sample.reshape(1, -1))
traversal_path = str(y_pred[0][0])
for pred in y_pred[0][1:]:
traversal_path = traversal_path + self.separator_ + pred

for i in range(self.max_levels_)[:-1]:
node = self.separator_.join(
traversal_path.split(self.separator_)[: i + 1]
)

local_classifier = self.hierarchy_.nodes[node]["classifier"]
pred_probabilities = local_classifier.predict_proba(
sample.reshape(1, -1)
)
predict_proba_dict[node] = pred_probabilities
prediction_probabilities.append(predict_proba_dict)

if len(prediction_probabilities) == 1:
return prediction_probabilities[0]

return prediction_probabilities
4 changes: 4 additions & 0 deletions hiclass/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Init module for the library."""

import os
from ._version import get_versions
from .LocalClassifierPerLevel import LocalClassifierPerLevel
from .LocalClassifierPerNode import LocalClassifierPerNode
from .LocalClassifierPerParentNode import LocalClassifierPerParentNode
from .MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
from .MultiLabelLocalClassifierPerParentNode import (
MultiLabelLocalClassifierPerParentNode,
)
from .Explainer import Explainer
from ._version import get_versions

__version__ = get_versions()["version"]
Expand All @@ -16,6 +19,7 @@
"LocalClassifierPerNode",
"LocalClassifierPerParentNode",
"LocalClassifierPerLevel",
"Explainer",
"MultiLabelLocalClassifierPerNode",
"MultiLabelLocalClassifierPerParentNode",
]
Loading
Loading