diff --git a/skl2onnx/shape_calculators/__init__.py b/skl2onnx/shape_calculators/__init__.py index ab5556b1e..71f9042cb 100644 --- a/skl2onnx/shape_calculators/__init__.py +++ b/skl2onnx/shape_calculators/__init__.py @@ -10,6 +10,7 @@ from . import cross_decomposition from . import dict_vectorizer from . import ensemble_shapes +from . import feature_selection from . import feature_hasher from . import flatten from . import function_transformer @@ -63,6 +64,7 @@ dict_vectorizer, ensemble_shapes, feature_hasher, + feature_selection, flatten, function_transformer, gaussian_process, diff --git a/skl2onnx/shape_calculators/array_feature_extractor.py b/skl2onnx/shape_calculators/array_feature_extractor.py index 1e8959f26..e1ac00b09 100644 --- a/skl2onnx/shape_calculators/array_feature_extractor.py +++ b/skl2onnx/shape_calculators/array_feature_extractor.py @@ -16,14 +16,3 @@ def calculate_sklearn_array_feature_extractor(operator): register_shape_calculator( "SklearnArrayFeatureExtractor", calculate_sklearn_array_feature_extractor ) - - -def calculate_sklearn_select_k_best(operator): - check_input_and_output_numbers(operator, output_count_range=1) - i = operator.inputs[0] - N = i.get_first_dimension() - C = operator.raw_operator._get_support_mask().sum() - operator.outputs[0].type = i.type.__class__([N, C]) - - -register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select_k_best) diff --git a/skl2onnx/shape_calculators/concat.py b/skl2onnx/shape_calculators/concat.py index f6049a9de..9561c03af 100644 --- a/skl2onnx/shape_calculators/concat.py +++ b/skl2onnx/shape_calculators/concat.py @@ -93,13 +93,3 @@ def more_generic(t1, t2): register_shape_calculator("SklearnConcat", calculate_sklearn_concat) -register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_concat) -register_shape_calculator("SklearnRFE", calculate_sklearn_concat) -register_shape_calculator("SklearnRFECV", calculate_sklearn_concat) -register_shape_calculator("SklearnSelectFdr", calculate_sklearn_concat) -register_shape_calculator("SklearnSelectFpr", calculate_sklearn_concat) -register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_concat) -register_shape_calculator("SklearnSelectFwe", calculate_sklearn_concat) -# register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat) -register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_concat) -register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_concat) diff --git a/skl2onnx/shape_calculators/feature_selection.py b/skl2onnx/shape_calculators/feature_selection.py new file mode 100644 index 000000000..3f935ea3c --- /dev/null +++ b/skl2onnx/shape_calculators/feature_selection.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 + + +from ..common._registration import register_shape_calculator +from ..common.utils import check_input_and_output_numbers + + +def calculate_sklearn_select(operator): + check_input_and_output_numbers(operator, output_count_range=1) + i = operator.inputs[0] + N = i.get_first_dimension() + C = operator.raw_operator._get_support_mask().sum() + operator.outputs[0].type = i.type.__class__([N, C]) + + +register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_select) +register_shape_calculator("SklearnRFE", calculate_sklearn_select) +register_shape_calculator("SklearnRFECV", calculate_sklearn_select) +register_shape_calculator("SklearnSelectFdr", calculate_sklearn_select) +register_shape_calculator("SklearnSelectFpr", calculate_sklearn_select) +register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_select) +register_shape_calculator("SklearnSelectFwe", calculate_sklearn_select) +register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select) +register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_select) +register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_select) diff --git a/tests/test_sklearn_feature_selection_converters.py b/tests/test_sklearn_feature_selection_converters.py index 6d5d7c955..3a6c48fbe 100644 --- a/tests/test_sklearn_feature_selection_converters.py +++ b/tests/test_sklearn_feature_selection_converters.py @@ -33,11 +33,13 @@ class TestSklearnFeatureSelectionConverters(unittest.TestCase): def test_generic_univariate_select_int(self): model = GenericUnivariateSelect() + X = np.array( [[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.int64 ) y = np.array([0, 1, 0, 1]) model.fit(X, y) + model_onnx = convert_sklearn( model, "generic univariate select",