Skip to content

Commit

Permalink
add another unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Dec 10, 2023
1 parent 8be9cd9 commit b7710a1
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/test_sklearn_ordinal_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,45 @@ def test_ordinal_encoder_pipeline_int64(self):
)
assert_almost_equal(expected, got[0].ravel())

@unittest.skipIf(
not ordinal_encoder_support(),
reason="OrdinalEncoder was not available before 0.20",
)
def test_ordinal_encoder_pipeline_string_int64(self):
from onnxruntime import InferenceSession

data = pd.DataFrame(
{"C1": ["cat2", "cat1", "cat3"], "C2": [1, 0, 1], "num": [0, 1, 1]}
)
data["num"] = data["num"].astype(np.float32)
y = np.array([0, 1, 2], dtype=np.float32)
preprocessor = ColumnTransformer(
transformers=[
("cat", OrdinalEncoder(dtype=np.int64), ["C1", "C2"]),
("num", "passthrough", ["num"]),
],
sparse_threshold=1,
verbose_feature_names_out=False,
).set_output(transform="pandas")
model = make_pipeline(
preprocessor, RandomForestRegressor(n_estimators=3, max_depth=2)
)
model.fit(data, y)
expected = model.predict(data)
model_onnx = to_onnx(model, data[:1], target_opset=TARGET_OPSET)
sess = InferenceSession(
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(
None,
{
"C1": data["C1"].values.reshape((-1, 1)),
"C2": data["C2"].values.reshape((-1, 1)),
"num": data["num"].values.reshape((-1, 1)),
},
)
assert_almost_equal(expected, got[0].ravel())


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit b7710a1

Please sign in to comment.