diff --git a/skglm/estimators.py b/skglm/estimators.py
index cc488a422..8fb32e75b 100644
--- a/skglm/estimators.py
+++ b/skglm/estimators.py
@@ -967,6 +967,12 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
     alpha : float, default=1.0
         Regularization strength; must be a positive float.
 
+    l1_ratio : float, default=1.0
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
+
     tol : float, optional
         Stopping criterion for the optimization.
 
@@ -1003,10 +1009,11 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
         Number of subproblems solved to reach the specified tolerance.
     """
 
-    def __init__(self, alpha=1.0, tol=1e-4, max_iter=20, max_epochs=1_000, verbose=0,
-                 fit_intercept=True, warm_start=False):
+    def __init__(self, alpha=1.0, l1_ratio=1.0, tol=1e-4, max_iter=20, max_epochs=1_000,
+                 verbose=0, fit_intercept=True, warm_start=False):
         super().__init__()
         self.alpha = alpha
+        self.l1_ratio = l1_ratio
         self.tol = tol
         self.max_iter = max_iter
         self.max_epochs = max_epochs
@@ -1035,7 +1042,8 @@ def fit(self, X, y):
             max_iter=self.max_iter, max_pn_iter=self.max_epochs, tol=self.tol,
             fit_intercept=self.fit_intercept, warm_start=self.warm_start,
             verbose=self.verbose)
-        return _glm_fit(X, y, self, Logistic(), L1(self.alpha), solver)
+        return _glm_fit(X, y, self, Logistic(), L1_plus_L2(self.alpha, self.l1_ratio),
+                        solver)
 
     def predict_proba(self, X):
         """Probability estimates.
diff --git a/skglm/tests/test_estimators.py b/skglm/tests/test_estimators.py
index 0998cfbe5..ec7536f19 100644
--- a/skglm/tests/test_estimators.py
+++ b/skglm/tests/test_estimators.py
@@ -600,5 +600,25 @@ def test_GroupLasso_estimator_sparse_vs_dense(positive):
     np.testing.assert_allclose(coef_sparse, coef_dense, atol=1e-7, rtol=1e-5)
 
 
+@pytest.mark.parametrize("X, l1_ratio", product([X, X_sparse], [1., 0.7, 0.]))
+def test_SparseLogReg_elasticnet(X, l1_ratio):
+
+    estimator_sk = clone(dict_estimators_sk['LogisticRegression'])
+    estimator_ours = clone(dict_estimators_ours['LogisticRegression'])
+    estimator_sk.set_params(fit_intercept=True, solver='saga',
+                            penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000)
+    estimator_ours.set_params(fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000)
+
+    estimator_sk.fit(X, y)
+    estimator_ours.fit(X, y)
+    coef_sk = estimator_sk.coef_
+    coef_ours = estimator_ours.coef_
+
+    np.testing.assert_array_less(1e-5, norm(coef_ours))
+    np.testing.assert_allclose(coef_ours, coef_sk, atol=1e-6)
+    np.testing.assert_allclose(
+        estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4)
+
+
 if __name__ == "__main__":
     pass