diff --git a/skl2onnx/shape_calculators/__init__.py b/skl2onnx/shape_calculators/__init__.py
index ab5556b1e..71f9042cb 100644
--- a/skl2onnx/shape_calculators/__init__.py
+++ b/skl2onnx/shape_calculators/__init__.py
@@ -10,6 +10,7 @@
 from . import cross_decomposition
 from . import dict_vectorizer
 from . import ensemble_shapes
 from . import feature_hasher
+from . import feature_selection
 from . import flatten
 from . import function_transformer
@@ -63,6 +64,7 @@
     dict_vectorizer,
     ensemble_shapes,
     feature_hasher,
+    feature_selection,
     flatten,
     function_transformer,
     gaussian_process,
diff --git a/skl2onnx/shape_calculators/concat.py b/skl2onnx/shape_calculators/concat.py
index afdb6ab9e..9561c03af 100644
--- a/skl2onnx/shape_calculators/concat.py
+++ b/skl2onnx/shape_calculators/concat.py
@@ -93,13 +93,3 @@ def more_generic(t1, t2):
 
 
 register_shape_calculator("SklearnConcat", calculate_sklearn_concat)
-register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_concat)
-register_shape_calculator("SklearnRFE", calculate_sklearn_concat)
-register_shape_calculator("SklearnRFECV", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectFdr", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectFpr", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectFwe", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat)
-register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_concat)
-register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_concat)
diff --git a/skl2onnx/shape_calculators/feature_selection.py b/skl2onnx/shape_calculators/feature_selection.py
new file mode 100644
index 000000000..8c1164817
--- /dev/null
+++ b/skl2onnx/shape_calculators/feature_selection.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+from ..common._registration import register_shape_calculator
+from ..common.utils import check_input_and_output_numbers
+
+
+def calculate_sklearn_select(operator):
+    """Compute the ``[N, C]`` output shape of a fitted feature selector.
+
+    ``N`` is the first dimension of the first input.  ``C`` is the sum of
+    ``operator.raw_operator.get_support()``, i.e. the number of features
+    the fitted selector keeps.  The output type is an instance of the
+    same class as the input type.
+    """
+    check_input_and_output_numbers(operator, output_count_range=1)
+    i = operator.inputs[0]
+    N = i.get_first_dimension()
+    # ``sum()`` over the support mask returns a numpy integer; convert it
+    # so the column dimension stored in the shape is a plain Python int.
+    C = int(operator.raw_operator.get_support().sum())
+    operator.outputs[0].type = i.type.__class__([N, C])
+
+
+register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_select)
+register_shape_calculator("SklearnRFE", calculate_sklearn_select)
+register_shape_calculator("SklearnRFECV", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectFdr", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectFpr", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectFwe", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select)
+register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_select)
+register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_select)
diff --git a/tests/test_sklearn_feature_selection_converters.py b/tests/test_sklearn_feature_selection_converters.py
index 1feaa11e3..7ede71464 100644
--- a/tests/test_sklearn_feature_selection_converters.py
+++ b/tests/test_sklearn_feature_selection_converters.py
@@ -21,6 +21,9 @@
     SelectPercentile,
     VarianceThreshold,
 )
+from sklearn.pipeline import make_pipeline
+from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVR
 from skl2onnx import convert_sklearn
 from skl2onnx.common.data_types import Int64TensorType, FloatTensorType
@@ -30,11 +33,33 @@
 class TestSklearnFeatureSelectionConverters(unittest.TestCase):
     def test_generic_univariate_select_int(self):
         model = GenericUnivariateSelect()
+
         X = np.array(
             [[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.int64
         )
         y = np.array([0, 1, 0, 1])
         model.fit(X, y)
+
+        model_onnx = convert_sklearn(
+            model,
+            "generic univariate select",
+            [("input", Int64TensorType([None, X.shape[1]]))],
+            target_opset=TARGET_OPSET,
+        )
+        self.assertTrue(model_onnx is not None)
+        dump_data_and_model(
+            X, model, model_onnx, basename="SklearnGenericUnivariateSelect"
+        )
+
+    def test_generic_univariate_select_kbest_int(self):
+        model = GenericUnivariateSelect(mode="k_best", param=2)
+
+        X = np.array(
+            [[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.int64
+        )
+        y = np.array([0, 1, 0, 1])
+        model.fit(X, y)
+
         model_onnx = convert_sklearn(
             model,
             "generic univariate select",
@@ -340,6 +365,23 @@ def test_select_k_best_float(self):
         self.assertTrue(model_onnx is not None)
         dump_data_and_model(X, model, model_onnx, basename="SklearnSelectKBest")
 
+    def test_select_k_best_scaler_logistic_regression_pipeline_float(self):
+        model = make_pipeline(SelectKBest(k=3), StandardScaler(), LogisticRegression())
+
+        X = np.array(
+            [[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.float32
+        )
+        y = np.array([0, 1, 0, 1])
+        model.fit(X, y)
+        model_onnx = convert_sklearn(
+            model,
+            "select k best",
+            [("input", FloatTensorType([None, X.shape[1]]))],
+            target_opset=TARGET_OPSET,
+        )
+        self.assertTrue(model_onnx is not None)
+        dump_data_and_model(X, model, model_onnx, basename="SklearnSelectKBest")
+
     def test_select_percentile_float(self):
         model = SelectPercentile()
         X = np.array(