Skip to content

Commit

Permalink
tests added + some bugs fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
iwan-tee committed Jan 15, 2024
1 parent 7dcb52f commit 1de360c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hiclass/Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _explain_lcpn(self, X):
local_explainer = self.explainers[node]

# Calculate SHAP values for the given sample X
shap_values = local_explainer.shap_values(X)
shap_values = np.array(local_explainer.shap_values(X))
shap_values_dict[node] = shap_values

return shap_values_dict
Expand Down
60 changes: 60 additions & 0 deletions tests/test_LocalClassifierPerNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,63 @@ def test_explainer_not_empty():
explainer = Explainer(lcpn, data=X, mode="tree")
shap_dict = explainer.explain(X_test)
assert shap_dict is not None


@pytest.fixture
def explainer_data():
# a
# / \
# b c
# / \ / \
# d e f g
x_train = np.random.randn(4, 3)
y_train = np.array(
[["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]]
)
x_test = np.random.randn(5, 3)

return x_train, x_test, y_train


def test_explainer_tree(explainer_data):
rfc = RandomForestClassifier()
lcpn = LocalClassifierPerNode(
local_classifier=rfc,
)

x_train, x_test, y_train = explainer_data
print(explainer_data)

lcpn.fit(x_train, y_train)

lcpn.predict(x_test)
explainer = Explainer(lcpn, data=x_train, mode="tree")
shap_dict = explainer.explain(x_test)

for key, val in shap_dict.items():
# Assert on shapes of shap values, must match (target_classes, num_samples, num_features)
model = lcpn.hierarchy_.nodes[key]["classifier"]
assert shap_dict[key].shape == (
len(model.classes_),
x_test.shape[0],
x_test.shape[1],
)


def test_explainer_linear(explainer_data):
logreg = LogisticRegression()
lcpn = LocalClassifierPerNode(
local_classifier=logreg,
)

x_train, x_test, y_train = explainer_data
lcpn.fit(x_train, y_train)

lcpn.predict(x_test)
explainer = Explainer(lcpn, data=x_train, mode="linear")
shap_dict = explainer.explain(x_test)

for key, val in shap_dict.items():
# Assert on shapes of shap values, must match (num_samples, num_features) Note: Logistic regression is based
# on sigmoid and not softmax, hence there are no separate predictions for each target class
assert shap_dict[key].shape == x_test.shape

0 comments on commit 1de360c

Please sign in to comment.