diff --git a/tests/test_sklearn_pipeline.py b/tests/test_sklearn_pipeline.py index 928d90886..620c7fc22 100644 --- a/tests/test_sklearn_pipeline.py +++ b/tests/test_sklearn_pipeline.py @@ -1335,6 +1335,32 @@ def test_pipeline_make_column_selector(self): ) assert_almost_equal(expected, got[0]) + @unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available") + @unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old") + @ignore_warnings(category=(FutureWarning, UserWarning)) + def test_pipeline_empty_make_column_selector(self): + X = pandas.DataFrame( + {"city": ["London", "London", "Paris", "Sallisaw"]} + ) + + ct = make_column_transformer( + (StandardScaler(), make_column_selector(dtype_include=numpy.number)), + (OneHotEncoder(), make_column_selector(dtype_include=object)), + ) + expected = ct.fit_transform(X) + onx = to_onnx(ct, X, target_opset=TARGET_OPSET) + sess = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + names = [i.name for i in sess.get_inputs()] + got = sess.run( + None, + { + names[0]: X[names[0]].values.reshape((-1, 1)), + }, + ) + assert_almost_equal(expected, got[0]) + @unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old") def test_feature_selector_no_converter(self): class ColumnSelector(TransformerMixin, BaseEstimator):