diff --git a/doc/api.rst b/doc/api.rst
index 25fc8ed8..2b5bdd9e 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -44,3 +44,9 @@ Robust
     robust.RobustWeightedClassifier
     robust.RobustWeightedRegressor
     robust.RobustWeightedKMeans
+
+.. autosummary::
+   :toctree: generated/
+   :template: function.rst
+
+   robust.make_huber_metric
diff --git a/doc/changelog.rst b/doc/changelog.rst
index c8e42339..27874f5f 100644
--- a/doc/changelog.rst
+++ b/doc/changelog.rst
@@ -4,6 +4,8 @@ Changelog
 Unreleased
 ----------
 
+- Add `make_huber_metric`, which transforms a non-robust metric into a robust
+  one using the Huber estimator.
 - Add a stopping criterion and parameter tuning heuristic for Huber robust
   mean estimator.
 - Add `CLARA` (Clustering for Large Applications) which extends k-medoids to
diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst
index cdd47308..00845c41 100644
--- a/doc/modules/robust.rst
+++ b/doc/modules/robust.rst
@@ -142,6 +142,45 @@ This algorithm has been studied in the context of "mom" weights in the article
 [1]_, the context of "huber" weights has been mentioned in [2]_. Both weighting
 schemes can be seen as special cases of the algorithm in [3]_.
 
+
+.. _make_huber_metric:
+
+Robust model selection
+----------------------
+
+One of the big challenges of robust machine learning is that the usual scoring
+schemes (cross-validation with mean squared error, for instance) are not
+robust. Indeed, if the dataset has some outliers, then the test sets in
+cross-validation may also contain outliers, and the cross-validation MSE would
+then report a huge error for our robust algorithm on any corrupted data.
+
+To solve this problem, one can use a robust scoring method in cross-validation,
+constructed with `make_huber_metric`. See the following example:
+
+`An example of a robust cross-validation evaluation in regression
+<../auto_examples/robust/plot_robust_cv_example.html>`_
+
+This type of robust cross-validation was mentioned for instance in [4]_.
+
+Here is what `make_huber_metric` computes. Suppose that we would usually
+compute the loss as follows:
+
+.. math::
+
+   \widehat L = \frac{1}{n}\sum_{i=1}^n \ell(Y_i, f(X_i))
+
+`make_huber_metric` replaces this computation with
+
+.. math::
+
+   \widehat L_{rob} = \widehat{\mathrm{Hub}}\left(\ell(Y_i, f(X_i))\right)
+
+where :math:`\widehat{\mathrm{Hub}}` is the Huber estimator of location,
+a robust estimator of the mean (a similar result can also be obtained with the
+trimmed mean). :math:`\widehat{L}_{rob}` is robust in the sense that an
+especially large value of :math:`\ell(Y_i, f(X_i))` does not change the result
+by much. The constant `c` used to tune :math:`\widehat{\mathrm{Hub}}` plays the
+same role in tuning the robustness as it does in regression and classification
+with Huber weights.
+
 Comparison with other robust estimators
 ---------------------------------------
@@ -203,3 +242,7 @@ the example with California housing real dataset, for further discussion.
 .. [3] Stanislav Minsker and Timothée Mathieu.
        `"Excess risk bounds in robust empirical risk minimization" `_
        arXiv preprint (2019). arXiv:1910.07485.
+
+.. [4] Elvezio Ronchetti, Christopher Field & Wade Blanchard,
+       `"Robust Linear Model Selection by Cross-Validation" `_
+       Journal of the American Statistical Association (1995).
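Note (an illustration, not part of the patch): to make the formulas above concrete,
here is a minimal self-contained sketch of Huber aggregation of per-sample losses.
The helper name ``huber_location``, the fixed ``c``, and the toy data are assumptions
for this sketch only; this is not the implementation added in
``sklearn_extra/robust/mean_estimators.py``::

    import numpy as np

    def huber_location(x, c=1.0, n_iter=20):
        # Huber estimator of location via iterative reweighting:
        # points within c of the current estimate get weight 1,
        # points further away get weight c / |distance|.
        mu = np.median(x)  # robust starting point
        for _ in range(n_iter):
            d = x - mu
            w = np.where(np.abs(d) <= c, 1.0, c / np.maximum(np.abs(d), 1e-12))
            mu = np.sum(w * x) / np.sum(w)  # weighted-mean update
        return mu

    rng = np.random.RandomState(0)
    y_true = np.zeros(100)
    y_pred = rng.normal(scale=0.1, size=100)
    y_pred[:3] = 50.0  # a few wildly wrong predictions (outliers)

    losses = (y_true - y_pred) ** 2        # per-sample squared errors
    print(np.mean(losses))                 # plain MSE, blown up by the outliers
    print(huber_location(losses, c=1.0))   # Huber-aggregated loss stays small

As in the documentation above, a small ``c`` pushes the aggregated value toward
the median of the losses, and a large ``c`` toward their plain mean.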
diff --git a/examples/robust/plot_robust_cv_example.py b/examples/robust/plot_robust_cv_example.py
new file mode 100644
index 00000000..19d66fd5
--- /dev/null
+++ b/examples/robust/plot_robust_cv_example.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+"""
+================================================================
+An example of a robust cross-validation evaluation in regression
+================================================================
+
+In this example we compare `LinearRegression` (OLS) with `HuberRegressor` from
+scikit-learn using cross-validation.
+
+We show that a robust cross-validation scheme gives a better evaluation of the
+generalisation error on a corrupted dataset.
+
+Here, we do robust cross-validation by using an alternative to the empirical
+mean to aggregate the errors. This alternative is a robust estimator of the
+mean (the trimmed mean is one example of such an estimator, but here we use
+Huber's estimator). The robust estimator of the mean is applied to the errors
+on each fold of the cross-validation, and we then return the empirical mean of
+the robust scores obtained on the folds as the final score.
+"""
+print(__doc__)
+
+import numpy as np
+from sklearn.metrics import mean_squared_error, make_scorer
+from sklearn.model_selection import cross_val_score
+from sklearn.linear_model import LinearRegression, HuberRegressor
+
+from sklearn_extra.robust import make_huber_metric
+
+robust_mse = make_huber_metric(mean_squared_error, c=9)
+rng = np.random.RandomState(42)
+
+X = rng.uniform(size=100)[:, np.newaxis]
+y = 3 * X.ravel()
+# Remark: the uncorrupted values satisfy y <= 3.
+
+y[[42 // 2, 42, 42 * 2]] = 200  # outliers
+
+print("Non-robust error:")
+for reg in [LinearRegression(), HuberRegressor()]:
+    score = np.mean(
+        cross_val_score(reg, X, y, scoring=make_scorer(mean_squared_error))
+    )
+    print(reg, " mse : %.2f" % score)
+
+print("\n")
+print("Robust error:")
+for reg in [LinearRegression(), HuberRegressor()]:
+    score = np.mean(
+        cross_val_score(reg, X, y, scoring=make_scorer(robust_mse))
+    )
+    print(reg, " mse : %.2f" % score)
diff --git a/sklearn_extra/robust/__init__.py b/sklearn_extra/robust/__init__.py
index 640ad475..97861ee6 100644
--- a/sklearn_extra/robust/__init__.py
+++ b/sklearn_extra/robust/__init__.py
@@ -3,9 +3,12 @@
     RobustWeightedKMeans,
     RobustWeightedRegressor,
 )
+from sklearn_extra.robust.mean_estimators import huber, make_huber_metric
 
 __all__ = [
     "RobustWeightedClassifier",
     "RobustWeightedKMeans",
     "RobustWeightedRegressor",
+    "huber",
+    "make_huber_metric",
 ]
diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py
index 42df44f0..e6040352 100644
--- a/sklearn_extra/robust/mean_estimators.py
+++ b/sklearn_extra/robust/mean_estimators.py
@@ -4,6 +4,8 @@
 # License: BSD 3 clause
 
 import numpy as np
+from scipy.stats import iqr
+from sklearn.metrics import mean_squared_error
 
 
 def block_mom(X, k, random_state):
@@ -88,7 +90,7 @@ def median_of_means(X, k, random_state=np.random.RandomState(42)):
     return median_of_means_blocked(x, blocks)[0]
 
 
-def huber(X, c=None, T=20, tol=1e-3):
+def huber(X, c=None, n_iter=20, tol=1e-3):
     """Compute the Huber estimator of location of X with parameter c
 
     Parameters
     ----------
@@ -104,7 +106,7 @@
         if c is None, the interquartile range (IQR) is used as heuristic.
 
-    T : int, default = 20
+    n_iter : int, default = 20
        Number of iterations of the algorithm.
 
    tol : float, default=1e-3
@@ -138,7 +140,7 @@ def psisx(x, c):
     last_mu = mu
 
     # Run the iterative reweighting algorithm to compute the M-estimator.
-    for t in range(T):
+    for t in range(n_iter):
 
         # Compute the weights
         w = psisx(x - mu, c_numeric)
@@ -156,3 +158,70 @@ def psisx(x, c):
         last_mu = mu
 
     return mu
+
+
+def make_huber_metric(
+    score_func=mean_squared_error, sample_weight=None, c=None, n_iter=20
+):
+    """
+    Make a robust metric using the Huber estimator.
+
+    Read more in the :ref:`User Guide <make_huber_metric>`.
+
+    Parameters
+    ----------
+    score_func : callable
+        Score function (or loss function) with signature
+        ``score_func(y, y_pred, **kwargs)``.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    c : float > 0, default=None
+        Parameter that controls the robustness of the estimator.
+        As c goes to zero, the behavior gets close to that of the median;
+        as c goes to infinity, it gets close to that of the sample mean.
+        If c is None, the interquartile range (IQR) is used as a heuristic.
+
+    n_iter : int, default=20
+        Number of iterations of the algorithm.
+
+    Returns
+    -------
+    metric : callable
+        Robust metric function with signature ``metric(y_true, y_pred)``.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import mean_squared_error
+    >>> from sklearn_extra.robust import make_huber_metric
+    >>> robust_mse = make_huber_metric(mean_squared_error, c=5)
+    >>> y_true = np.hstack([np.zeros(98), 20 * np.ones(2)])  # corrupted test values
+    >>> np.random.shuffle(y_true)  # shuffle them
+    >>> y_pred = np.zeros(100)  # predicted values
+    >>> result = robust_mse(y_true, y_pred)
+    """
+
+    def metric(y_true, y_pred):
+        # Wrap y_true and y_pred so that score_func sees one sample with
+        # n outputs; multioutput="raw_values" then returns one value per
+        # original sample instead of a single aggregated score.
+        y1 = [y_true]
+        y2 = [y_pred]
+        values = score_func(
+            y1, y2, sample_weight=sample_weight, multioutput="raw_values"
+        )
+        if c is None:
+            c_ = iqr(values)
+        else:
+            c_ = c
+        if c_ == 0:
+            # A zero IQR would make the Huber weights degenerate; fall back
+            # to the median, the limit of the Huber estimator as c -> 0.
+            return np.median(values)
+        else:
+            return huber(values, c_, n_iter)
+
+    return metric
diff --git a/sklearn_extra/robust/tests/test_mean_estimators.py b/sklearn_extra/robust/tests/test_mean_estimators.py
index 0bce4dab..f93f0f6d 100644
--- a/sklearn_extra/robust/tests/test_mean_estimators.py
+++ b/sklearn_extra/robust/tests/test_mean_estimators.py
@@ -1,8 +1,14 @@
 import numpy as np
 import pytest
 
-from sklearn_extra.robust.mean_estimators import median_of_means, huber
-
+from sklearn_extra.robust.mean_estimators import (
+    median_of_means,
+    huber,
+    make_huber_metric,
+)
+from sklearn.metrics import mean_squared_error, make_scorer
+from sklearn.model_selection import cross_val_score
+from sklearn.linear_model import HuberRegressor
 
 rng = np.random.RandomState(42)
 
@@ -30,3 +36,29 @@ def test_huber():
         mu = huber(X, c=0.5)
     assert len(record) == 0
     assert np.abs(mu) < 0.1
+
+
+def test_robust_metric():
+    robust_mse = make_huber_metric(mean_squared_error, c=5)
+    y_true = np.hstack([np.zeros(95), 20 * np.ones(5)])
+    np.random.shuffle(y_true)
+    y_pred = np.zeros(100)
+
+    assert robust_mse(y_true, y_pred) < 1
+
+
+def test_check_robust_cv():
+    robust_mse = make_huber_metric(mean_squared_error, c=9)
+    rng = np.random.RandomState(42)
+
+    X = rng.uniform(size=100)[:, np.newaxis]
+    y = 3 * X.ravel()
+
+    y[[42 // 2, 42, 42 * 2]] = 200  # outliers
+
+    huber_reg = HuberRegressor()
+    error_Hub_reg = np.mean(
+        cross_val_score(huber_reg, X, y, scoring=make_scorer(robust_mse))
+    )
+    assert error_Hub_reg < 1
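Note (an illustration, not part of the patch): the closure returned by
``make_huber_metric`` relies on scikit-learn's documented ``multioutput="raw_values"``
option to recover one loss per sample, which is what the Huber estimator of
location is then applied to. Wrapping ``y_true`` and ``y_pred`` in one-element
lists turns the samples axis into the outputs axis. A small sketch of that
behavior on toy data::

    import numpy as np
    from sklearn.metrics import mean_squared_error

    y_true = np.array([0.0, 0.0, 0.0, 10.0])
    y_pred = np.zeros(4)

    # Usual call: a single score aggregated over the four samples.
    print(mean_squared_error(y_true, y_pred))  # 25.0

    # Wrapped call: the samples become "outputs", so raw_values yields
    # one squared error per original sample.
    values = mean_squared_error([y_true], [y_pred], multioutput="raw_values")
    print(values)  # [  0.   0.   0. 100.]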