Add filter functions for explanations and shap_values #120

Status: Open. Wants to merge 31 commits into base branch `main`.

Commits (31):
- 5af9efb: basic implementation of filter_by_level function (iwan-tee, Apr 14, 2024)
- 5e52e86: basic implementation of filter_by_class function (iwan-tee, Apr 14, 2024)
- 5a8e9be: basic implementation of combine_filters function (iwan-tee, Apr 14, 2024)
- cceb6d1: codestyling (iwan-tee, Apr 14, 2024)
- ab1e3ab: helper function get_class_level added (iwan-tee, Apr 14, 2024)
- 719dd83: another helper functions added (iwan-tee, Apr 14, 2024)
- 343be04: some bugs fixed (iwan-tee, Apr 14, 2024)
- af32acb: small plot_lcpl_explainer actualisation (iwan-tee, Apr 14, 2024)
- 3645e38: some small changes in plot_lcpl_explainer (iwan-tee, Apr 14, 2024)
- d1cf3e0: another changes in plot_lcpl_explainer (iwan-tee, Apr 14, 2024)
- 7e6bcb3: functionality duplication removes + docstrings added (partially) (iwan-tee, Apr 14, 2024)
- 31e88b2: some actualization in plot_lcpl_explainer (iwan-tee, Apr 14, 2024)
- 849b30f: some refactoring (iwan-tee, Apr 14, 2024)
- 9877374: some refactoring (iwan-tee, Apr 14, 2024)
- 38a9060: some cases handled + tests written (iwan-tee, Apr 15, 2024)
- 159bd17: small changes (iwan-tee, Apr 15, 2024)
- 5faf54b: pydocstyle (iwan-tee, Apr 15, 2024)
- c9e6aa2: documentation fixed (iwan-tee, Apr 15, 2024)
- 3d2bf3b: Algorithm overviem completed (iwan-tee, Apr 15, 2024)
- a2e7700: plot_lcppn_explainer actualized (iwan-tee, Apr 15, 2024)
- 3b8239f: shap_multi_plot added (iwan-tee, Apr 15, 2024)
- c06e5e6: part of the code substituted with newer methods (iwan-tee, Apr 15, 2024)
- 63c635f: pydocstyling (iwan-tee, Apr 15, 2024)
- 6953517: some cases handled and new tests added (iwan-tee, Apr 15, 2024)
- e3a54ee: matplotlib requirement added (iwan-tee, Apr 15, 2024)
- 73f74f6: algorithm explaining updated (iwan-tee, Apr 15, 2024)
- 4ab46ad: small changes in indices selection (iwan-tee, Apr 15, 2024)
- e431803: some skipiff added (iwan-tee, Apr 16, 2024)
- bf64bfe: ray support added and used as a default (instead of joblib) (iwan-tee, Apr 16, 2024)
- 1047a2c: pydocs (iwan-tee, Apr 16, 2024)
- c1dd746: pydocs (iwan-tee, Apr 16, 2024)
3 changes: 2 additions & 1 deletion Pipfile
@@ -7,6 +7,7 @@ name = "pypi"
networkx = "*"
numpy = "*"
scikit-learn = "*"
matplotlib = "*"

[dev-packages]
pytest = "*"
@@ -20,4 +21,4 @@ sphinx-rtd-theme = "0.5.2"
[extras]
ray = "*"
shap = "0.44.1"
xarray = "*"
xarray = "*"
35 changes: 7 additions & 28 deletions docs/examples/plot_lcpl_explainer.py
@@ -25,35 +25,14 @@

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Let's filter the Shapley values corresponding to the Covid (level 1)
# and 'Respiratory' (level 0)
# Now, our task is to see how feature importance may vary from level to level
# We are going to calculate shap_values for 'Respiratory', 'Covid' and plot what we calculated
# This can be done with a single method .shap_multi_plot, which additionally returns calculated explanations

covid_idx = classifier.predict(X_test)[:, 1] == "Covid"

shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
shap_val_covid = explanations.sel(**shap_filter_covid)
shap_val_resp = explanations.sel(**shap_filter_resp)


# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases.

# Feature names for the X-axis
feature_names = X_train.columns.values

# SHAP values for 'Covid'
shap_values_covid = shap_val_covid.shap_values.values

# SHAP values for 'Respiratory'
shap_values_resp = shap_val_resp.shap_values.values

shap.summary_plot(
[shap_values_covid, shap_values_resp],
features=X_test.iloc[covid_idx],
feature_names=X_train.columns.values,
plot_type="bar",
explanations = explainer.shap_multi_plot(
class_names=["Covid", "Respiratory"],
features=X_test.values,
pred_class="Respiratory",
features_names=X_train.columns.values,
)
18 changes: 10 additions & 8 deletions docs/examples/plot_lcppn_explainer.py
@@ -25,23 +25,25 @@
# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Get predictions
predictions = classifier.predict(X_test)

# Define Explainer
explainer = Explainer(classifier, data=X_train.values, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Filter samples which only predicted "Respiratory" at first level
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Specify additional filters to obtain only level 0
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx}

# Use .sel() method to apply the filter and obtain filtered results
shap_val_respiratory = explanations.sel(shap_filter)
shap_val_respiratory = explainer.filter_by_class(
[Review comment] mirand863 (Collaborator), Apr 15, 2024: "Alternatively, I guess you can probably call the method get_sample_indices inside this other method filter_by_class, simplifying it for the user"
[Reply] iwan-tee (Author): "settled in this option for now"

explanations,
class_name="Respiratory",
sample_indices=explainer.get_sample_indices(predictions, "Respiratory"),
)
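The review exchange above is about whether `filter_by_class` should call `get_sample_indices` internally or leave the two steps to the user. As a rough, hypothetical sketch of the two-step pattern (this is not hiclass's actual implementation; the array shapes and helper bodies are assumptions for illustration):

```python
import numpy as np

def get_sample_indices(predictions, class_name):
    # Hypothetical helper: boolean mask of samples whose predicted
    # path contains class_name at any level.
    return np.any(predictions == class_name, axis=1)

def filter_by_class(shap_values, class_names, class_name, sample_indices):
    # Hypothetical filter: take the SHAP slice for class_name,
    # restricted to the selected samples.
    class_idx = class_names.index(class_name)
    return shap_values[class_idx][sample_indices]

# Toy data: 2 samples, prediction paths with 2 levels each
predictions = np.array([["Respiratory", "Covid"],
                        ["Gastrointestinal", "Norovirus"]])
# Toy SHAP array shaped (class, sample, feature)
shap_values = np.arange(12, dtype=float).reshape(2, 2, 3)

idx = get_sample_indices(predictions, "Respiratory")
filtered = filter_by_class(shap_values, ["Respiratory", "Gastrointestinal"],
                           "Respiratory", idx)
print(idx.tolist())       # [True, False]
print(filtered.tolist())  # [[0.0, 1.0, 2.0]]
```

Folding the helper call into `filter_by_class`, as the reviewer suggests, would reduce the user-facing API to a single call at the cost of a less composable building block.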


# Plot feature importance on test set
shap.plots.violin(
shap_val_respiratory.shap_values,
shap_val_respiratory,
feature_names=X_train.columns.values,
plot_size=(13, 8),
)
29 changes: 25 additions & 4 deletions docs/source/algorithms/explainer.rst
@@ -111,21 +111,42 @@ Code sample

lcppn.fit(x_train, y_train)
explainer = Explainer(lcppn, data=x_train, mode="tree")

# One of the possible ways to get explanations
explanations = explainer.explain(x_test)


++++++++++++++++++++++++++
Filtering and Manipulation
++++++++++++++++++++++++++

The Explanation object returned by the Explainer is built using the :literal:`xarray.Dataset` data structure, which enables the application of any xarray dataset operation. For example, specific values can be quickly filtered out. To illustrate the filtering operation, suppose we have SHAP values stored in an Explanation object named :literal:`explanation`.
When you work with the `Explanation` object generated by the `Explainer`, you are working with an `xarray.Dataset`. This structure is both robust and flexible, and supports the full range of xarray dataset operations, in particular filtering.

**Practical Example: Filtering SHAP Values**

A common use case is to extract SHAP values for only the predicted nodes. In the local classifier per parent node approach, each node except the leaf nodes represents a classifier. Hence, to find the SHAP values, we can follow the prediction path up to the penultimate element.
To achieve this, we can use xarray's :literal:`.sel()` method:
Consider a scenario where you need to focus only on SHAP values corresponding to predicted nodes. In the `LocalClassifierPerParentNode` model, each node except the leaf nodes acts as a classifier, so you will often want to isolate the SHAP values along a prediction path up to its penultimate node. Here is how you can do this efficiently with the `sel()` method from xarray:

.. code-block:: python

# Creating a mask for selecting SHAP values for predicted classes
mask = {'class': lcppn.predict(x_test).flatten()[:-1]}
x = explanations.sel(mask).shap_values
selected_shap_values = explanations.sel(mask).shap_values
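For intuition, here is a toy sketch of what `.sel()` returns on a minimal hand-built Dataset (the dimension names mirror the Explanation object described above; the class names and values are made up for illustration):

```python
import numpy as np
import xarray as xr

# Minimal stand-in for an Explanation: dims (class, sample, feature)
explanations = xr.Dataset(
    {"shap_values": (("class", "sample", "feature"),
                     np.arange(12, dtype=float).reshape(2, 2, 3))},
    coords={"class": ["Respiratory", "Covid"],
            "sample": [0, 1],
            "feature": ["f1", "f2", "f3"]},
)

# Select one class and a subset of samples, like the mask above
subset = explanations.sel({"class": "Respiratory", "sample": [0]})
print(subset.shap_values.values.tolist())  # [[0.0, 1.0, 2.0]]
```

Selecting a single label along the `class` dimension drops that dimension, while passing a list for `sample` keeps the dimension with only the requested entries.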

**Advanced Visualization: Multi-Plot SHAP Values**

For an even deeper analysis, you might want to visualize the SHAP values. The `shap_multi_plot()` method not only filters the data but also provides a visual representation of the SHAP values for specified classes. Below is an example that illustrates how to plot SHAP values for the classes "Covid" and "Respiratory":

.. code-block:: python

# Generating and plotting explanations for specific classes
explanations = explainer.shap_multi_plot(
class_names=["Covid", "Respiratory"],
features=x_test,
pred_class="Covid",
# Feature names can be specified if x_train is a DataFrame with named columns
feature_names=x_train.columns.values
)



More advanced usage and capabilities can be found at the `Xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_ documentation.