NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Line boundaries follow the "+"/"-" markers and
the hunk counts. The indentation of context lines, the position of the two
blank context lines in the hiclass/Explainer.py hunk, and the spacing of the
ASCII tree comment are inferred, not recovered; confirm them against the
repository before applying.
Changes relative to the original patch: the leftover print(explainer_data)
debug call was removed from test_explainer_tree (the new-side hunk count is
therefore 62 instead of 63), the shape assertions use the loop value instead
of re-indexing shap_dict, and the test file now ends with a newline. The
"index" lines still name the original blobs.

diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py
index a762e967..902a5d92 100644
--- a/hiclass/Explainer.py
+++ b/hiclass/Explainer.py
@@ -128,7 +128,7 @@ def _explain_lcpn(self, X):
             local_explainer = self.explainers[node]
 
             # Calculate SHAP values for the given sample X
-            shap_values = local_explainer.shap_values(X)
+            shap_values = np.array(local_explainer.shap_values(X))
             shap_values_dict[node] = shap_values
 
         return shap_values_dict
diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py
index 1596cc71..2e7fb9fb 100644
--- a/tests/test_LocalClassifierPerNode.py
+++ b/tests/test_LocalClassifierPerNode.py
@@ -227,3 +227,62 @@ def test_explainer_not_empty():
     explainer = Explainer(lcpn, data=X, mode="tree")
     shap_dict = explainer.explain(X_test)
     assert shap_dict is not None
+
+
+@pytest.fixture
+def explainer_data():
+    #          a
+    #        /   \
+    #       b     c
+    #      / \   / \
+    #     d   e f   g
+    x_train = np.random.randn(4, 3)
+    y_train = np.array(
+        [["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]]
+    )
+    x_test = np.random.randn(5, 3)
+
+    return x_train, x_test, y_train
+
+
+def test_explainer_tree(explainer_data):
+    rfc = RandomForestClassifier()
+    lcpn = LocalClassifierPerNode(
+        local_classifier=rfc,
+    )
+
+    x_train, x_test, y_train = explainer_data
+
+    lcpn.fit(x_train, y_train)
+
+    lcpn.predict(x_test)
+    explainer = Explainer(lcpn, data=x_train, mode="tree")
+    shap_dict = explainer.explain(x_test)
+
+    for key, val in shap_dict.items():
+        # Assert on shapes of shap values, must match (target_classes, num_samples, num_features)
+        model = lcpn.hierarchy_.nodes[key]["classifier"]
+        assert val.shape == (
+            len(model.classes_),
+            x_test.shape[0],
+            x_test.shape[1],
+        )
+
+
+def test_explainer_linear(explainer_data):
+    logreg = LogisticRegression()
+    lcpn = LocalClassifierPerNode(
+        local_classifier=logreg,
+    )
+
+    x_train, x_test, y_train = explainer_data
+    lcpn.fit(x_train, y_train)
+
+    lcpn.predict(x_test)
+    explainer = Explainer(lcpn, data=x_train, mode="linear")
+    shap_dict = explainer.explain(x_test)
+
+    for key, val in shap_dict.items():
+        # Assert on shapes of shap values, must match (num_samples, num_features) Note: Logistic regression is based
+        # on sigmoid and not softmax, hence there are no separate predictions for each target class
+        assert val.shape == x_test.shape