Skip to content

Commit

Permalink
Fix all selectors
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre Bartet <[email protected]>
  • Loading branch information
Pierre Bartet committed Jan 24, 2025
1 parent f92052f commit 00e9045
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 21 deletions.
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import cross_decomposition
from . import dict_vectorizer
from . import ensemble_shapes
from . import feature_selection
from . import feature_hasher
from . import flatten
from . import function_transformer
Expand Down Expand Up @@ -63,6 +64,7 @@
dict_vectorizer,
ensemble_shapes,
feature_hasher,
feature_selection,
flatten,
function_transformer,
gaussian_process,
Expand Down
11 changes: 0 additions & 11 deletions skl2onnx/shape_calculators/array_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,3 @@ def calculate_sklearn_array_feature_extractor(operator):
register_shape_calculator(
"SklearnArrayFeatureExtractor", calculate_sklearn_array_feature_extractor
)


def calculate_sklearn_select_k_best(operator):
check_input_and_output_numbers(operator, output_count_range=1)
i = operator.inputs[0]
N = i.get_first_dimension()
C = operator.raw_operator._get_support_mask().sum()
operator.outputs[0].type = i.type.__class__([N, C])


register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select_k_best)
10 changes: 0 additions & 10 deletions skl2onnx/shape_calculators/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,3 @@ def more_generic(t1, t2):


register_shape_calculator("SklearnConcat", calculate_sklearn_concat)
register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_concat)
register_shape_calculator("SklearnRFE", calculate_sklearn_concat)
register_shape_calculator("SklearnRFECV", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFdr", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFpr", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFwe", calculate_sklearn_concat)
# register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_concat)
register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_concat)
25 changes: 25 additions & 0 deletions skl2onnx/shape_calculators/feature_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0


from ..common._registration import register_shape_calculator
from ..common.utils import check_input_and_output_numbers


def calculate_sklearn_select(operator):
check_input_and_output_numbers(operator, output_count_range=1)
i = operator.inputs[0]
N = i.get_first_dimension()
C = operator.raw_operator._get_support_mask().sum()
operator.outputs[0].type = i.type.__class__([N, C])


register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_select)
register_shape_calculator("SklearnRFE", calculate_sklearn_select)
register_shape_calculator("SklearnRFECV", calculate_sklearn_select)
register_shape_calculator("SklearnSelectFdr", calculate_sklearn_select)
register_shape_calculator("SklearnSelectFpr", calculate_sklearn_select)
register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_select)
register_shape_calculator("SklearnSelectFwe", calculate_sklearn_select)
register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select)
register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_select)
register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_select)
2 changes: 2 additions & 0 deletions tests/test_sklearn_feature_selection_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
class TestSklearnFeatureSelectionConverters(unittest.TestCase):
def test_generic_univariate_select_int(self):
model = GenericUnivariateSelect()

X = np.array(
[[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.int64
)
y = np.array([0, 1, 0, 1])
model.fit(X, y)

model_onnx = convert_sklearn(
model,
"generic univariate select",
Expand Down

0 comments on commit 00e9045

Please sign in to comment.