Skip to content

Commit

Permalink
feat: SHAP value support for XGBoost's binary classification models (#1660)
Browse files Browse the repository at this point in the history

* dispatch SHAP settings for XGBoost clsf

* enable XGBoost Classification SHAP check

* Update checks for SHAP binary classification

* add pred_contribs/pred_interactions keyword support

* fix classification SHAP value tests

* fix some test failures

* include daal_check_version

* remove circular import

* forgotten evaluation

* disable tests for older onedal versions

* change tolerances

* change correct tolerances

* return to original design

* fix for 2024.7

* modify tests

* forgotten formatting

---------

Co-authored-by: icfaust <[email protected]>
  • Loading branch information
ahuber21 and icfaust authored Sep 6, 2024
1 parent 45fc83d commit 48714b0
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 20 deletions.
72 changes: 68 additions & 4 deletions daal4py/mb/model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def _convert_model(self, model):
else:
raise TypeError(f"Unknown model format {submodule_name}.{class_name}")

def _predict_classification(self, X, fptype, resultsToEvaluate):
def _predict_classification(
self, X, fptype, resultsToEvaluate, pred_contribs=False, pred_interactions=False
):
if X.shape[1] != self.n_features_in_:
raise ValueError("Shape of input is different from what was seen in `fit`")

Expand All @@ -203,8 +205,34 @@ def _predict_classification(self, X, fptype, resultsToEvaluate):
)

# Prediction
try:
return self._predict_classification_with_results_to_compute(
X, fptype, resultsToEvaluate, pred_contribs, pred_interactions
)
except TypeError as e:
if "unexpected keyword argument 'resultsToCompute'" in str(e):
if pred_contribs or pred_interactions:
# SHAP values requested, but not supported by this version
raise TypeError(
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} not supported by this version of daal4py"
) from e
else:
# unknown type error
raise
except RuntimeError as e:
if "Method is not implemented" in str(e):
if pred_contribs or pred_interactions:
raise NotImplementedError(
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
)
else:
raise

# fallback to calculation without `resultsToCompute`
predict_algo = d4p.gbt_classification_prediction(
fptype=fptype, nClasses=self.n_classes_, resultsToEvaluate=resultsToEvaluate
nClasses=self.n_classes_,
fptype=fptype,
resultsToEvaluate=resultsToEvaluate,
)
predict_result = predict_algo.compute(X, self.daal_model_)

Expand All @@ -213,6 +241,40 @@ def _predict_classification(self, X, fptype, resultsToEvaluate):
else:
return predict_result.probabilities

def _predict_classification_with_results_to_compute(
    self,
    X,
    fptype,
    resultsToEvaluate,
    pred_contribs=False,
    pred_interactions=False,
):
    """Run GBT classification prediction, passing ``resultsToCompute`` to daal4py.

    This path assumes the installed daal4py accepts the ``resultsToCompute``
    keyword of ``d4p.gbt_classification_prediction``.

    Parameters
    ----------
    X : array-like with a ``shape`` attribute
        Samples to predict; ``X.shape[1]`` sizes the SHAP output.
    fptype
        Forwarded unchanged to ``d4p.gbt_classification_prediction``.
    resultsToEvaluate : str
        Forwarded unchanged; ``"computeClassLabels"`` selects label output.
    pred_contribs : bool, default False
        Request SHAP contributions. Wins over ``pred_interactions`` when both
        are set.
    pred_interactions : bool, default False
        Request SHAP interaction values.

    Returns
    -------
    The prediction reshaped to ``(-1, n_features + 1)`` for contributions,
    ``(-1, n_features + 1, n_features + 1)`` for interactions, a flat
    ``np.int64`` label array for ``"computeClassLabels"``, and otherwise the
    ``probabilities`` attribute of the daal4py result.
    """
    # An empty string means "no SHAP output requested".
    if pred_contribs:
        shap_request = "shapContributions"
    elif pred_interactions:
        shap_request = "shapInteractions"
    else:
        shap_request = ""

    algorithm = d4p.gbt_classification_prediction(
        nClasses=self.n_classes_,
        fptype=fptype,
        resultsToCompute=shap_request,
        resultsToEvaluate=resultsToEvaluate,
    )
    result = algorithm.compute(X, self.daal_model_)

    if pred_contribs:
        # one column per feature plus one extra column
        return result.prediction.ravel().reshape((-1, X.shape[1] + 1))
    if pred_interactions:
        width = X.shape[1] + 1
        return result.prediction.ravel().reshape((-1, width, width))
    if resultsToEvaluate == "computeClassLabels":
        return result.prediction.ravel().astype(np.int64, copy=False)
    return result.probabilities

def _predict_regression(
self, X, fptype, pred_contribs=False, pred_interactions=False
):
Expand Down Expand Up @@ -278,11 +340,13 @@ def predict(self, X, pred_contribs=False, pred_interactions=False):
if self._is_regression:
return self._predict_regression(X, fptype, pred_contribs, pred_interactions)
else:
if pred_contribs or pred_interactions:
if (pred_contribs or pred_interactions) and self.model_type != "xgboost":
raise NotImplementedError(
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
)
return self._predict_classification(X, fptype, "computeClassLabels")
return self._predict_classification(
X, fptype, "computeClassLabels", pred_contribs, pred_interactions
)

def _check_proba(self):
return not self._is_regression
Expand Down
16 changes: 11 additions & 5 deletions daal4py/sklearn/ensemble/GBTDAAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def fit(self, X, y):
# Return the classifier
return self

def _predict(self, X, resultsToEvaluate):
def _predict(
self, X, resultsToEvaluate, pred_contribs=False, pred_interactions=False
):
# Input validation
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
Expand All @@ -208,17 +210,21 @@ def _predict(self, X, resultsToEvaluate):
return np.full(X.shape[0], self.classes_[0])

fptype = getFPType(X)
predict_result = self._predict_classification(X, fptype, resultsToEvaluate)
predict_result = self._predict_classification(
X, fptype, resultsToEvaluate, pred_contribs, pred_interactions
)

if resultsToEvaluate == "computeClassLabels":
if resultsToEvaluate == "computeClassLabels" and not (
pred_contribs or pred_interactions
):
# Decode labels
le = preprocessing.LabelEncoder()
le.classes_ = self.classes_
return le.inverse_transform(predict_result)
return predict_result

def predict(self, X, pred_contribs=False, pred_interactions=False):
    """Predict class labels, or SHAP values when one of the flags is set.

    Parameters
    ----------
    X
        Samples to predict; forwarded unchanged to ``self._predict``.
    pred_contribs : bool, default False
        Forwarded to ``self._predict`` to request SHAP contributions.
    pred_interactions : bool, default False
        Forwarded to ``self._predict`` to request SHAP interaction values.

    Returns
    -------
    Whatever ``self._predict`` returns for the ``"computeClassLabels"`` mode.
    """
    # Both flags default to False, so existing ``predict(X)`` callers are
    # unaffected by the extended signature.
    return self._predict(X, "computeClassLabels", pred_contribs, pred_interactions)

def predict_proba(self, X):
    """Return the result of ``self._predict`` in class-probability mode.

    ``X`` is forwarded unchanged; no SHAP flags are passed on this path.
    """
    evaluation_mode = "computeClassProbabilities"
    return self._predict(X, evaluation_mode)
Expand Down
105 changes: 94 additions & 11 deletions tests/test_model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,14 @@


# Version gates for the SHAP tests. Each tuple is handed to
# ``daal_check_version`` and the boolean result drives the skip decorators.
shap_required_version = (2024, "P", 1)
shap_supported = daal_check_version(shap_required_version)

# NOTE(review): presumably the first release whose prediction API exposes
# SHAP values for classification -- confirm against the release notes.
shap_api_change_version = (2025, "P", 0)
shap_api_changed = daal_check_version(shap_api_change_version)

# Human-readable reasons shown when a test is skipped.
shap_not_supported_str = (
    f"SHAP value calculation only supported for version {shap_required_version} or later"
)
shap_api_change_str = "SHAP calculation requires 2025.0 API"
shap_unavailable_str = "SHAP Python package not available"
cb_unavailable_str = "CatBoost not available"

# CatBoost's SHAP value calculation seems to be buggy
Expand Down Expand Up @@ -208,15 +211,15 @@ def test_model_predict_shap_contribs_missing_values(self):
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-6)


# duplicate all tests for bae_score=0.0
# duplicate all tests for base_score=0.0
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
class XGBoostRegressionModelBuilder_base_score0(XGBoostRegressionModelBuilder):
    """Re-run the inherited XGBoost regression tests with a different fixture.

    The positional ``0`` handed to the parent's ``setUpClass`` is presumably
    the ``base_score`` -- confirm against the parent's signature.
    """

    @classmethod
    def setUpClass(cls):
        # The parent is called by name, so inside it ``cls`` is the parent
        # class; anything it stores on ``cls`` is reached here by inheritance.
        XGBoostRegressionModelBuilder.setUpClass(0)


# duplicate all tests for bae_score=100
# duplicate all tests for base_score=100
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
class XGBoostRegressionModelBuilder_base_score100(XGBoostRegressionModelBuilder):
@classmethod
Expand All @@ -235,7 +238,7 @@ def setUpClass(cls, base_score=0.5, n_classes=2, objective="binary:logistic"):
n_samples=500,
n_classes=n_classes,
n_features=n_features,
n_informative=10,
n_informative=(2 * n_features) // 3,
random_state=42,
)
cls.X_test = X[:2, :]
Expand Down Expand Up @@ -282,25 +285,59 @@ def test_missing_value_support(self):
def test_model_predict_shap_contribs(self):
    """Check SHAP contributions of the converted classifier against XGBoost.

    Three regimes are covered:

    * library older than the SHAP API change: ``predict`` must raise
      ``NotImplementedError``;
    * more than two classes: ``predict`` must raise ``RuntimeError`` with the
      multiclass message;
    * binary classification: contributions must match XGBoost's exact
      (non-approximated) contributions within ``rtol=1e-5``.
    """
    booster = self.xgb_model.get_booster()
    m = d4p.mb.convert_model(booster)

    if not shap_api_changed:
        with self.assertRaises(NotImplementedError):
            m.predict(self.X_test, pred_contribs=True)
    elif self.n_classes > 2:
        with self.assertRaisesRegex(
            RuntimeError, "Multiclass classification SHAP values not supported"
        ):
            m.predict(self.X_test, pred_contribs=True)
    else:
        d4p_pred = m.predict(self.X_test, pred_contribs=True)
        xgboost_pred = booster.predict(
            xgb.DMatrix(self.X_test),
            pred_contribs=True,
            approx_contribs=False,
            validate_features=False,
        )
        np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5)

def test_model_predict_shap_interactions(self):
    """Check SHAP interactions of the converted classifier against XGBoost.

    Three regimes are covered:

    * library older than the SHAP API change: ``predict`` must raise
      ``NotImplementedError``;
    * more than two classes: ``predict`` must raise ``RuntimeError`` with the
      multiclass message;
    * binary classification: interaction values must match XGBoost's exact
      (non-approximated) interactions within the tolerances below.
    """
    booster = self.xgb_model.get_booster()
    m = d4p.mb.convert_model(booster)

    if not shap_api_changed:
        with self.assertRaises(NotImplementedError):
            # Request interactions here: this test must exercise the
            # pred_interactions path, not the contributions path.
            m.predict(self.X_test, pred_interactions=True)
    elif self.n_classes > 2:
        with self.assertRaisesRegex(
            RuntimeError, "Multiclass classification SHAP values not supported"
        ):
            m.predict(self.X_test, pred_interactions=True)
    else:
        d4p_pred = m.predict(self.X_test, pred_interactions=True)
        xgboost_pred = booster.predict(
            xgb.DMatrix(self.X_test),
            pred_interactions=True,
            approx_contribs=False,
            validate_features=False,
        )
        # Classification scores are class probabilities between 0 and 1, so
        # the comparison hits floating-point precision limits. Accept large
        # relative differences as long as the absolute difference stays
        # small (<1e-6).
        np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=5e-2, atol=1e-6)


# duplicate all tests for base_score=0.3
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_base_score03(XGBoostClassificationModelBuilder):
    """Re-run the inherited XGBoost classification tests with ``base_score=0.3``."""

    @classmethod
    def setUpClass(cls):
        # The parent is called by name, so inside it ``cls`` is the parent
        # class; anything it stores on ``cls`` is reached here by inheritance.
        XGBoostClassificationModelBuilder.setUpClass(base_score=0.3)


# duplicate all tests for bae_score=0.7
# duplicate all tests for base_score=0.7
@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_base_score07(XGBoostClassificationModelBuilder):
@classmethod
Expand Down Expand Up @@ -328,6 +365,16 @@ def setUpClass(cls):
class XGBoostClassificationModelBuilder_objective_logitraw(
XGBoostClassificationModelBuilder
):
"""
Caveat: the logitraw objective is not natively supported in daal4py, because daal4py always
1. applies the bias
2. normalizes to probabilities ("activation") using sigmoid
(exception: SHAP values, where the scores defining phi_ij are the raw class scores).
However, by undoing the activation and the bias we can still check that the original
probabilities and SHAP values agree.
"""

@classmethod
def setUpClass(cls):
XGBoostClassificationModelBuilder.setUpClass(
Expand All @@ -352,6 +399,42 @@ def test_model_predict_proba(self):
# accept an rtol of 1e-5
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5)

@unittest.skipUnless(shap_api_changed, reason=shap_api_change_str)
def test_model_predict_shap_contribs(self):
    """Compare SHAP contributions with XGBoost for the logitraw objective.

    Converting the model is expected to emit a ``UserWarning``. Before the
    comparison, 0.5 is added to the last column of the daal4py result to
    undo the bias.
    """
    reference_booster = self.xgb_model.get_booster()

    # The conversion should warn that logitraw behaves differently and/or
    # that base_score is ignored / fixed to 0.5.
    with self.assertWarns(UserWarning):
        converted = d4p.mb.convert_model(self.xgb_model.get_booster())

    actual = converted.predict(self.X_test, pred_contribs=True)
    expected = reference_booster.predict(
        xgb.DMatrix(self.X_test),
        pred_contribs=True,
        approx_contribs=False,
        validate_features=False,
    )

    # undo bias (stored in the last column)
    actual[:, -1] += 0.5
    np.testing.assert_allclose(actual, expected, rtol=5e-6)

@unittest.skipUnless(shap_api_changed, reason=shap_api_change_str)
def test_model_predict_shap_interactions(self):
    """Compare SHAP interactions with XGBoost for the logitraw objective.

    Converting the model is expected to emit a ``UserWarning``. Before the
    comparison, 0.5 is added to the last diagonal entry of each sample's
    interaction matrix to undo the bias.
    """
    reference_booster = self.xgb_model.get_booster()

    # The conversion should warn that logitraw behaves differently and/or
    # that base_score is ignored / fixed to 0.5.
    with self.assertWarns(UserWarning):
        converted = d4p.mb.convert_model(self.xgb_model.get_booster())

    actual = converted.predict(self.X_test, pred_interactions=True)
    expected = reference_booster.predict(
        xgb.DMatrix(self.X_test),
        pred_interactions=True,
        approx_contribs=False,
        validate_features=False,
    )

    # undo bias (stored in the bottom-right cell of each matrix)
    actual[:, -1, -1] += 0.5
    np.testing.assert_allclose(actual, expected, rtol=5e-5)


@unittest.skipUnless(shap_supported, reason=shap_not_supported_str)
class LightGBMRegressionModelBuilder(unittest.TestCase):
Expand Down

0 comments on commit 48714b0

Please sign in to comment.