Skip to content

Commit

Permalink
Fix the problem?
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre Bartet <[email protected]>
  • Loading branch information
Pierre Bartet committed Jan 24, 2025
1 parent 483c6e3 commit 38d4e19
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions skl2onnx/shape_calculators/array_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ def calculate_sklearn_array_feature_extractor(operator):
register_shape_calculator(
"SklearnArrayFeatureExtractor", calculate_sklearn_array_feature_extractor
)

def calculate_sklearn_select_k_best(operator):
check_input_and_output_numbers(operator, output_count_range=1)
i = operator.inputs[0]
N = i.get_first_dimension()
C = operator.raw_operator._get_support_mask().sum()
operator.outputs[0].type = i.type.__class__([N, C])

register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select_k_best)
2 changes: 1 addition & 1 deletion skl2onnx/shape_calculators/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ def more_generic(t1, t2):
register_shape_calculator("SklearnSelectFpr", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFwe", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat)
# register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_concat)
register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_concat)

0 comments on commit 38d4e19

Please sign in to comment.