Skip to content

Commit

Permalink
Fix empty column selector (#1159)
Browse files Browse the repository at this point in the history
* Add a test to trigger the bug

Signed-off-by: Pierre Bartet <[email protected]>

* Fix _parse.py

Signed-off-by: Pierre Bartet <[email protected]>

* Fix formatting

Signed-off-by: Pierre Bartet <[email protected]>

---------

Signed-off-by: Pierre Bartet <[email protected]>
Co-authored-by: Pierre Bartet <[email protected]>
  • Loading branch information
Pierre-Bartet and Pierre Bartet authored Jan 24, 2025
1 parent e0799d3 commit 469a18b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ def _parse_sklearn_column_transformer(scope, model, inputs, custom_parsers=None)
elif isinstance(column_indices, (int, str)):
column_indices = [column_indices]
names = get_column_indices(column_indices, inputs, multiple=True)

# Skip transforms which apply to no columns at all
if len(names) == 0:
continue

transform_inputs = []
for onnx_var, onnx_is in names.items():
tr_inputs = _fetch_input_slice(scope, [inputs[onnx_var]], onnx_is)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,30 @@ def test_pipeline_make_column_selector(self):
)
assert_almost_equal(expected, got[0])

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
@ignore_warnings(category=(FutureWarning, UserWarning))
def test_pipeline_empty_make_column_selector(self):
    """Convert a column transformer in which one selector picks no column.

    The frame only holds the string column ``city``, so the selector
    built with ``dtype_include=numpy.number`` is expected to match
    nothing while the ``object`` selector feeds the one-hot encoder.
    The ONNX output must agree with scikit-learn's ``fit_transform``.
    """
    frame = pandas.DataFrame({"city": ["London", "London", "Paris", "Sallisaw"]})

    transformer = make_column_transformer(
        (StandardScaler(), make_column_selector(dtype_include=numpy.number)),
        (OneHotEncoder(), make_column_selector(dtype_include=object)),
    )
    expected = transformer.fit_transform(frame)

    # Convert to ONNX and load the serialized model on the CPU provider.
    onnx_model = to_onnx(transformer, frame, target_opset=TARGET_OPSET)
    session = InferenceSession(
        onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
    )

    # Feed only the first graph input, reshaped into a single column.
    input_names = [graph_input.name for graph_input in session.get_inputs()]
    first_name = input_names[0]
    feeds = {first_name: frame[first_name].values.reshape((-1, 1))}
    got = session.run(None, feeds)

    assert_almost_equal(expected, got[0])

@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
def test_feature_selector_no_converter(self):
class ColumnSelector(TransformerMixin, BaseEstimator):
Expand Down

0 comments on commit 469a18b

Please sign in to comment.