Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainer api for local classifiers #102

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
62c218d
added initial implementation of explainer api for lcppn
Jan 9, 2024
ea1fff8
fixed lints
Jan 10, 2024
c4d75c5
fixed lints
Jan 14, 2024
299af62
added an _explain_lcppn implementation and some tests provided
iwan-tee Jan 14, 2024
0ea8956
modified docstrings
Jan 14, 2024
1efd946
explainer for lcpn implemented + tests added and some cases fixed
iwan-tee Jan 14, 2024
7dcb52f
Merge branch 'explainer_api_lcpn' into explainer_api
iwan-tee Jan 14, 2024
1de360c
tests added + some bugs fixed
iwan-tee Jan 15, 2024
933b1f6
base
iwan-tee Jan 15, 2024
c57abed
basic implementation
iwan-tee Jan 18, 2024
a829ce0
LCPL explanator implementation + test
iwan-tee Jan 23, 2024
33f2cbc
added tests for hierarchy without roots
Jan 26, 2024
c06d8a7
check on root node added
iwan-tee Jan 26, 2024
c597fce
minor updates
Jan 26, 2024
b79e5f4
codestyling
iwan-tee Jan 26, 2024
8a643f1
codestyling
iwan-tee Jan 26, 2024
606c1eb
Merge branch 'explainer_master' into explainer_api_lcpl
ashishpatel16 Jan 26, 2024
ca6c654
Update Explainer.py
ashishpatel16 Jan 26, 2024
9936dc3
Merge pull request #1 from ashishpatel16/explainer_api_lcpl
ashishpatel16 Jan 26, 2024
82573be
Merge pull request #2 from ashishpatel16/explainer_api_lcpn
ashishpatel16 Jan 26, 2024
d53e8d9
added support for xarray for lcppn
Jan 29, 2024
2449928
Merge branch 'explainer_master' into explainer_api
ashishpatel16 Jan 29, 2024
759489f
Update Explainer.py
ashishpatel16 Jan 29, 2024
0771c08
Update Explainer.py
ashishpatel16 Jan 29, 2024
4eb6f5c
fixed errors with classifier with single class
Jan 30, 2024
3955521
updated test cases and removed cached explainers
Feb 1, 2024
7c2f4d2
removed cached explainers
Feb 1, 2024
b12bdc3
modified predict proba to return dict
Feb 1, 2024
986b61c
Merge branch 'main' into explainer_api
ashishpatel16 Feb 2, 2024
eb11c0e
updated get_predict_proba to return only traversed prediction probabi…
Feb 3, 2024
8c700e4
updated fork
Feb 3, 2024
53a90a0
separate test file for explainer
Feb 3, 2024
5e74762
Update Explainer.py
ashishpatel16 Feb 5, 2024
b1f3656
_get_traversed_nodes edited
iwan-tee Feb 6, 2024
2a12087
fixed lints
Feb 12, 2024
84f6e39
fixed conflicts
Feb 12, 2024
b09f8da
refactored and cleaned up code
Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,18 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
nodes = self._get_parents()
self._fit_node_classifier(nodes, local_mode, use_joblib)

def _get_predict_proba(model, X):
classifiers = [
model.hierarchy_.nodes[node]["classifier"]
for node in model.hierarchy_.nodes
if "classifier" in model.hierarchy_.nodes[node]
def get_predict_proba(self, X):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jannikgro I have updated the function to return a dict, however it currently just returns predict_proba for all nodes in the hierarchy. I'm woking on another version which only gives the dict for only traversed nodes.

classifier_nodes = [
node
for node in self.hierarchy_.nodes
if "classifier" in self.hierarchy_.nodes[node]
]

# This will give list of target labels in the same order as predict_proba probabilities
# target_labels = [clf.classes_ for clf in classifiers]
predict_proba_per_classifier = [clf.predict_proba(X) for clf in classifiers]
return predict_proba_per_classifier
predict_proba_dict = {}
for node in classifier_nodes:
pred_probabilities = self.hierarchy_.nodes[node][
"classifier"
].predict_proba(X)
predict_proba_dict[node] = pred_probabilities

return predict_proba_dict
41 changes: 28 additions & 13 deletions tests/test_LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,20 @@ def test_explainer_tree_traversal(explainer_data):
x_train, x_test, y_train = explainer_data

lcppn.fit(x_train, y_train)

explainer = Explainer(lcppn, data=x_train, mode="tree")
shap_dict = explainer.explain_traversed_nodes(x_test)
print(shap_dict)

for key, val in shap_dict.items():
# Assert on shapes of shap values, must match (target_classes, num_samples, num_features)
model = lcppn.hierarchy_.nodes[key]["classifier"]
assert shap_dict[key].shape == (
len(model.classes_),
x_test.shape[0],
x_test.shape[1],
)
print(lcppn.get_predict_proba(x_test))

# explainer = Explainer(lcppn, data=x_train, mode="tree")
# shap_dict = explainer.explain_traversed_nodes(x_test)
# print(shap_dict)
#
# for key, val in shap_dict.items():
# # Assert on shapes of shap values, must match (target_classes, num_samples, num_features)
# model = lcppn.hierarchy_.nodes[key]["classifier"]
# assert shap_dict[key].shape == (
# len(model.classes_),
# x_test.shape[0],
# x_test.shape[1],
# )


# TODO: Add new test cases with hierarchies without root nodes
Expand Down Expand Up @@ -312,3 +313,17 @@ def test_explainer_tree_no_root(explainer_data_no_root):
x_test.shape[0],
x_test.shape[1],
)


def test_predict_proba(explainer_data):
rfc = RandomForestClassifier()
lcppn = LocalClassifierPerParentNode(
local_classifier=rfc, replace_classifiers=False
)

x_train, x_test, y_train = explainer_data

lcppn.fit(x_train, y_train)

pred_proba_dict = lcppn.get_predict_proba(x_test)
assert pred_proba_dict is not None