Skip to content

Commit

Permalink
Add a test to trigger the bug
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre Bartet <[email protected]>
  • Loading branch information
Pierre Bartet authored and Pierre-Bartet committed Jan 23, 2025
1 parent e0799d3 commit 77310d5
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,32 @@ def test_pipeline_make_column_selector(self):
)
assert_almost_equal(expected, got[0])

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
@ignore_warnings(category=(FutureWarning, UserWarning))
def test_pipeline_empty_make_column_selector(self):
X = pandas.DataFrame(
{"city": ["London", "London", "Paris", "Sallisaw"]}
)

ct = make_column_transformer(
(StandardScaler(), make_column_selector(dtype_include=numpy.number)),
(OneHotEncoder(), make_column_selector(dtype_include=object)),
)
expected = ct.fit_transform(X)
onx = to_onnx(ct, X, target_opset=TARGET_OPSET)
sess = InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
names = [i.name for i in sess.get_inputs()]
got = sess.run(
None,
{
names[0]: X[names[0]].values.reshape((-1, 1)),
},
)
assert_almost_equal(expected, got[0])

@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
def test_feature_selector_no_converter(self):
class ColumnSelector(TransformerMixin, BaseEstimator):
Expand Down

0 comments on commit 77310d5

Please sign in to comment.