Add explainer for local classifier per level #minor #116

Merged
merged 88 commits into from
Apr 12, 2024
62c218d
added initial implementation of explainer api for lcppn
Jan 9, 2024
ea1fff8
fixed lints
Jan 10, 2024
c4d75c5
fixed lints
Jan 14, 2024
299af62
added an _explain_lcppn implementation and some tests provided
iwan-tee Jan 14, 2024
0ea8956
modified docstrings
Jan 14, 2024
1efd946
explainer for lcpn implemented + tests added and some cases fixed
iwan-tee Jan 14, 2024
7dcb52f
Merge branch 'explainer_api_lcpn' into explainer_api
iwan-tee Jan 14, 2024
1de360c
tests added + some bugs fixed
iwan-tee Jan 15, 2024
933b1f6
base
iwan-tee Jan 15, 2024
c57abed
basic implementation
iwan-tee Jan 18, 2024
a829ce0
LCPL explanator implementation + test
iwan-tee Jan 23, 2024
33f2cbc
added tests for hierarchy without roots
Jan 26, 2024
c06d8a7
check on root node added
iwan-tee Jan 26, 2024
c597fce
minor updates
Jan 26, 2024
b79e5f4
codestyling
iwan-tee Jan 26, 2024
8a643f1
codestyling
iwan-tee Jan 26, 2024
606c1eb
Merge branch 'explainer_master' into explainer_api_lcpl
ashishpatel16 Jan 26, 2024
ca6c654
Update Explainer.py
ashishpatel16 Jan 26, 2024
9936dc3
Merge pull request #1 from ashishpatel16/explainer_api_lcpl
ashishpatel16 Jan 26, 2024
82573be
Merge pull request #2 from ashishpatel16/explainer_api_lcpn
ashishpatel16 Jan 26, 2024
d53e8d9
added support for xarray for lcppn
Jan 29, 2024
2449928
Merge branch 'explainer_master' into explainer_api
ashishpatel16 Jan 29, 2024
759489f
Update Explainer.py
ashishpatel16 Jan 29, 2024
0771c08
Update Explainer.py
ashishpatel16 Jan 29, 2024
4eb6f5c
fixed errors with classifier with single class
Jan 30, 2024
3955521
updated test cases and removed cached explainers
Feb 1, 2024
7c2f4d2
removed cached explainers
Feb 1, 2024
b12bdc3
modified predict proba to return dict
Feb 1, 2024
986b61c
Merge branch 'main' into explainer_api
ashishpatel16 Feb 2, 2024
eb11c0e
updated get_predict_proba to return only traversed prediction probabi…
Feb 3, 2024
8c700e4
updated fork
Feb 3, 2024
53a90a0
separate test file for explainer
Feb 3, 2024
5e74762
Update Explainer.py
ashishpatel16 Feb 5, 2024
b1f3656
_get_traversed_nodes edited
iwan-tee Feb 6, 2024
2a12087
fixed lints
Feb 12, 2024
84f6e39
fixed conflicts
Feb 12, 2024
b09f8da
refactored and cleaned up code
Feb 12, 2024
aecdd96
updated test cases and isolated lcppn code
Feb 12, 2024
9a73b6c
added support for lcpn
Feb 16, 2024
c5b5a68
Merge branch 'main' into lcpn_explainer
ashishpatel16 Mar 14, 2024
41c17be
LCPL support added + test coverage
iwan-tee Mar 15, 2024
9f429e7
little details in test_Explainer
iwan-tee Mar 15, 2024
384120d
code standartization
iwan-tee Mar 27, 2024
03fd9ea
explainer redone + tests
iwan-tee Mar 27, 2024
f415020
little changes
iwan-tee Mar 27, 2024
93b611a
shap_installed check added
iwan-tee Mar 27, 2024
d68bcfd
decorators added
iwan-tee Mar 27, 2024
b477da3
some little changes
iwan-tee Mar 28, 2024
b2cbdd9
some conflicts resolving
iwan-tee Mar 28, 2024
863b000
read the docs added
iwan-tee Mar 28, 2024
9305879
some changes
iwan-tee Mar 28, 2024
84c5cd5
some edits in readthedocs
iwan-tee Mar 28, 2024
4ec3078
README file actualized
iwan-tee Apr 2, 2024
cfab6b8
shap requirement added
iwan-tee Apr 3, 2024
c4c3433
another dependencies added
iwan-tee Apr 3, 2024
d27aa5b
some codestyling
iwan-tee Apr 4, 2024
6f627f6
code standardization + conflicts resolving
iwan-tee Apr 4, 2024
d4abcfb
tests remastered
iwan-tee Apr 4, 2024
bb5a693
changes to resolve the conflicts
iwan-tee Apr 4, 2024
325a464
Merge branch 'main' into lcpl_explainer
iwan-tee Apr 4, 2024
4577a5c
codestyling
iwan-tee Apr 4, 2024
7f37044
problems in readthedocs resolved
iwan-tee Apr 4, 2024
05a2426
README edited (recovering deleted title)
iwan-tee Apr 8, 2024
8cf1376
plot_lcpl_explainer cleaned
iwan-tee Apr 8, 2024
1764ad2
plot_lcpl_explainer cleaning and formatting
iwan-tee Apr 8, 2024
ab255db
duplications removed
iwan-tee Apr 8, 2024
d7f3f65
datasets added
iwan-tee Apr 9, 2024
b9e14f1
Platypus dataset based example added
iwan-tee Apr 9, 2024
4e48e83
typo deleted
iwan-tee Apr 10, 2024
e3e3022
get_predict_proba function removed
iwan-tee Apr 10, 2024
5880731
code unification
iwan-tee Apr 10, 2024
6b73897
duplicating dependencies removed
iwan-tee Apr 10, 2024
7f3a129
another example added
iwan-tee Apr 10, 2024
f12c767
Merge branch 'main' into lcpl_explainer
iwan-tee Apr 10, 2024
bc973cb
LCPN in esplainer tests added
iwan-tee Apr 10, 2024
ec4ff48
get_traversed_nodes_lcpn added
iwan-tee Apr 10, 2024
5002d06
some bugs fixed and special cases added
iwan-tee Apr 10, 2024
622c6e8
some problems fixed
iwan-tee Apr 10, 2024
d683141
a typo fixed
iwan-tee Apr 10, 2024
8f0cb7f
typo fixed + another plot used
iwan-tee Apr 12, 2024
f69e6c0
tests fixed
iwan-tee Apr 12, 2024
29be63b
not needed import deleted
iwan-tee Apr 12, 2024
5e46e8b
not needed import deleted
iwan-tee Apr 12, 2024
3776429
not needed dependency deleted
iwan-tee Apr 12, 2024
bbce800
README file fixed
iwan-tee Apr 12, 2024
9f70469
datasets.py removed
iwan-tee Apr 12, 2024
7cc9ceb
outdated plot_data_filtering file deleted
iwan-tee Apr 12, 2024
a36261e
another plot added
iwan-tee Apr 12, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -202,7 +202,8 @@ predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level](https://colab.research.google.com/drive/1VnGlJu-1wSG4wxHXL0Ijf2a7Pu3kklT-?usp=sharing) is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

## Step-by-step walk-through

59 changes: 59 additions & 0 deletions docs/examples/plot_lcpl_explainer.py
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
=========================================
Explaining Local Classifier Per Level
=========================================

A minimalist example showing how to use the HiClass Explainer to obtain SHAP values of an LCPL model.
A detailed summary of the Explainer class is given in the Algorithms Overview section under :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerLevel, Explainer
import shap
from hiclass.datasets import load_platypus

# Load train and test splits
X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every level
rfc = RandomForestClassifier()
classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

# Train local classifiers per level
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Let's filter the Shapley values corresponding to the Covid (level 1)
# and 'Respiratory' (level 0)

covid_idx = classifier.predict(X_test)[:, 1] == "Covid"

shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
shap_val_covid = explanations.sel(**shap_filter_covid)
shap_val_resp = explanations.sel(**shap_filter_resp)


# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases.

# Feature names for the X-axis
feature_names = X_train.columns.values

# SHAP values for 'Covid'
shap_values_covid = shap_val_covid.shap_values.values

# SHAP values for 'Respiratory'
shap_values_resp = shap_val_resp.shap_values.values

shap.summary_plot(
    [shap_values_covid, shap_values_resp],
    features=X_test.iloc[covid_idx],
    feature_names=X_train.columns.values,
    plot_type="bar",
    class_names=["Covid", "Respiratory"],
)
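
With `plot_type="bar"`, the summary plot ranks features by their mean absolute SHAP value for each class. As a rough illustration of that aggregation, the sketch below computes the same quantity by hand in plain Python. The feature names and SHAP values are made up for the illustration and are not taken from the platypus dataset.

```python
# Hypothetical SHAP values: one row per sample, one column per feature.
feature_names = ["fever", "cough", "fatigue"]
shap_covid = [[0.20, -0.10, 0.05], [0.40, 0.30, -0.05]]
shap_resp = [[-0.10, 0.20, 0.00], [0.30, -0.40, 0.10]]


def mean_abs_per_feature(rows):
    """Average the absolute SHAP value of each feature over all samples."""
    n_samples = len(rows)
    return [sum(abs(v) for v in column) / n_samples for column in zip(*rows)]


importance = {
    "Covid": mean_abs_per_feature(shap_covid),
    "Respiratory": mean_abs_per_feature(shap_resp),
}

for i, name in enumerate(feature_names):
    covid = importance["Covid"][i]
    resp = importance["Respiratory"][i]
    print(f"{name}: Covid={covid:.3f}, Respiratory={resp:.3f}")
```

Taking the absolute value before averaging keeps positive and negative contributions from cancelling, so a feature that pushes predictions strongly in both directions still ranks high.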
55 changes: 45 additions & 10 deletions hiclass/Explainer.py
@@ -229,6 +229,31 @@ def _get_traversed_nodes_lcpn(self, samples):

        return traversals

    def _get_traversed_nodes_lcpl(self, samples):
        """
        Return a list of all traversed nodes as per the provided LocalClassifierPerLevel model.

        Parameters
        ----------
        samples : array-like
            Sample data for which to generate traversed nodes.

        Returns
        -------
        traversals : list
            A list of all traversed nodes as per LocalClassifierPerLevel (LCPL) strategy.
        """
        traversals = []
        predictions = self.hierarchical_model.predict(samples)
        for pred in predictions:
            traversal_order = []
            filtered_pred = [p for p in pred if p.strip()]
            for i in range(1, len(filtered_pred) + 1):
                node = self.hierarchical_model.separator_.join(filtered_pred[:i])
                traversal_order.append(node)
            traversals.append(traversal_order)
        return traversals

    def _calculate_shap_values(self, X):
        """
        Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute.
@@ -244,23 +269,27 @@ def _calculate_shap_values(self, X):
            A single explanation for the prediction of given sample.
        """
        traversed_nodes = []
        if isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
        if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
            traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
        elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
            traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
        elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
            traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]
        datasets = []
        level = 0
        for node in traversed_nodes:
            # Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
            if (
                node == ""
                or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
            if node == "" or (
                ("classifier" not in self.hierarchical_model.hierarchy_.nodes[node])
                and (not isinstance(self.hierarchical_model, LocalClassifierPerLevel))
            ):
                continue

            local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
                "classifier"
            ]
            if isinstance(self.hierarchical_model, LocalClassifierPerLevel):
                local_classifier = self.hierarchical_model.local_classifiers_[level]
            else:
                local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
                    "classifier"
                ]

            # Create a SHAP explainer for the local classifier
            local_explainer = deepcopy(self.explainer)(local_classifier, self.data)
@@ -283,7 +312,7 @@ def _calculate_shap_values(self, X):
                    for label in local_classifier.classes_
                ]
                predicted_class = current_node
            else:
            elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
                simplified_labels = [
                    label.split(self.hierarchical_model.separator_)[-1]
                    for label in local_classifier.classes_
@@ -293,6 +322,12 @@ def _calculate_shap_values(self, X):
                    .flatten()[0]
                    .split(self.hierarchical_model.separator_)[-1]
                )
            else:
                simplified_labels = [
                    label.split(self.hierarchical_model.separator_)[-1]
                    for label in local_classifier.classes_
                ]
                predicted_class = current_node

            classes = xr.DataArray(
                simplified_labels,
@@ -326,7 +361,7 @@ def _calculate_shap_values(self, X):
                    "level": level,
                }
            )
            level = level + 1
            level += 1
            datasets.append(local_dataset)
        sample_explanation = xr.concat(datasets, dim="level")
        return sample_explanation
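
To make the LCPL path concrete: `_get_traversed_nodes_lcpl` turns each predicted label path into cumulative node names joined by the model's separator, and `_calculate_shap_values` then pairs the node at position `level` with `local_classifiers_[level]`. The sketch below reproduces only that string handling in plain Python, without HiClass. The `SEPARATOR` value, the `traversed_nodes` helper name, and the sample predictions are placeholders chosen for illustration.

```python
SEPARATOR = "::"  # placeholder; the real value comes from the fitted model's separator_


def traversed_nodes(predictions, separator=SEPARATOR):
    """Build cumulative node names for every predicted path, skipping empty labels."""
    traversals = []
    for pred in predictions:
        labels = [p for p in pred if p.strip()]
        traversals.append(
            [separator.join(labels[:i]) for i in range(1, len(labels) + 1)]
        )
    return traversals


# Hypothetical predictions; the second path stops early (imbalanced hierarchy).
predictions = [["Respiratory", "Covid"], ["Allergy", ""]]
paths = traversed_nodes(predictions)

for path in paths:
    # The position in the path is the level whose classifier would be explained.
    for level, node in enumerate(path):
        print(level, node)
```

Because empty labels are filtered out before the prefixes are built, a shorter path simply yields fewer nodes, and the level index still lines up with the per-level classifier list.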
55 changes: 51 additions & 4 deletions tests/test_Explainer.py
@@ -1,7 +1,12 @@
import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode, Explainer
from hiclass import (
    LocalClassifierPerLevel,
    LocalClassifierPerParentNode,
    LocalClassifierPerNode,
    Explainer,
)

try:
    import shap
@@ -98,6 +103,26 @@ def test_explainer_tree_lcpn(data, request):
            assert str(explanations["node"][i].data[j]) == y_pred[j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_explainer_tree_lcpl(data, request):
    rfc = RandomForestClassifier()
    lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

    x_train, x_test, y_train = request.getfixturevalue(data)

    lcpl.fit(x_train, y_train)

    explainer = Explainer(lcpl, data=x_train, mode="tree")
    explanations = explainer.explain(x_test)
    assert explanations is not None
    y_preds = lcpl.predict(x_test)
    for i in range(len(x_test)):
        y_pred = y_preds[i]
        for j in range(len(y_pred)):
            assert str(explanations["node"][i].data[j]) == y_pred[j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcppn(data, request):
Expand Down Expand Up @@ -142,11 +167,30 @@ def test_traversal_path_lcpn(data, request):
assert label == preds[i][j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
def test_traversal_path_lcpl(data, request):
x_train, x_test, y_train = request.getfixturevalue(data)
rfc = RandomForestClassifier()
lcpl = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

lcpl.fit(x_train, y_train)
explainer = Explainer(lcpl, data=x_train, mode="tree")
traversals = explainer._get_traversed_nodes_lcpl(x_test)
preds = lcpl.predict(x_test)
assert len(preds) == len(traversals)
for i in range(len(x_test)):
for j in range(len(traversals[i])):
label = traversals[i][j].split(lcpl.separator_)[-1]
assert label == preds[i][j]


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.skipif(not xarray_installed, reason="xarray not installed")
@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"])
@pytest.mark.parametrize(
    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
    "classifier",
    [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
)
def test_explain_with_xr(data, request, classifier):
    x_train, x_test, y_train = request.getfixturevalue(data)
@@ -162,7 +206,8 @@ def test_explain_with_xr(data, request, classifier):


@pytest.mark.parametrize(
    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
    "classifier",
    [LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode],
)
def test_imports(classifier):
    x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]]
@@ -176,8 +221,10 @@ def test_imports(classifier):
    assert isinstance(explainer.data, np.ndarray)


@pytest.mark.skipif(not shap_installed, reason="shap not installed")
@pytest.mark.parametrize(
    "classifier", [LocalClassifierPerParentNode, LocalClassifierPerNode]
    "classifier",
    [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode],
)
@pytest.mark.parametrize("data", ["explainer_data"])
@pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""])