Skip to content

Commit

Permalink
Investigate issue 1129 (#1131)
Browse files Browse the repository at this point in the history
* add test to investigate

Signed-off-by: xadupre <[email protected]>

* Add unit test from issue 1129

Signed-off-by: xadupre <[email protected]>

* Update CI with the latest verions of onnx (#1130)

Signed-off-by: xadupre <[email protected]>

* fix missing provider

Signed-off-by: xadupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Jan 23, 2025
1 parent 1627bf1 commit e0799d3
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
56 changes: 56 additions & 0 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,62 @@ def Classifier(features: list[str]) -> base.BaseEstimator:
)
assert modelengine is not None

def test_issue_1129_lr(self):

import numpy as np
from numpy.testing import assert_almost_equal
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import skl2onnx
from onnxruntime import InferenceSession

# Create a small dataframe with 10 rows and 2 columns
np.random.seed(0)
data = {
"float_column": np.random.rand(10).astype(np.float64),
"int_column": np.random.randint(0, 100, size=10).astype(np.int64),
}
x_ = pd.DataFrame(data)
y = np.random.binomial(1, 0.5, size=10)

# Create a test dataset with 10 rows
test_data = {
"float_column": np.random.rand(10).astype(np.float64),
"int_column": np.random.randint(0, 100, size=10).astype(np.int64),
}
x_test_ = pd.DataFrame(test_data)

for cls in [LogisticRegression, DecisionTreeClassifier, RandomForestClassifier]:
with self.subTest(cls=cls):
# Select and train a model
if cls == LogisticRegression:
x = x_.astype(np.float64)
x_test = x_test_.astype(np.float64)
decimal = 10
else:
x = x_.astype(np.float32)
x_test = x_test_.astype(np.float32)
decimal = 4
model = cls()
model.fit(x, y)
# Take predictions and probabilities with sklearn
sklearn_preds = model.predict(x_test)
sklearn_probs = model.predict_proba(x_test)

# Convert the model to ONNX
onnx_model = skl2onnx.to_onnx(
model, x.values, options={"zipmap": False}
)
# Take predictions and probabilities with ONNX
sess = InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_prediction = sess.run(None, {"X": x_test.to_numpy()})
assert_almost_equal(sklearn_probs, onnx_prediction[1], decimal=decimal)
assert_almost_equal(sklearn_preds, onnx_prediction[0])


if __name__ == "__main__":
unittest.main(verbosity=2)
4 changes: 2 additions & 2 deletions tests/test_sklearn_pipeline_concat_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ def test_issue_712_svc_binary_empty(self):
target_opset=TARGET_OPSET,
options={CountVectorizer: {"keep_empty_string": True}},
)
with open("debug.onnx", "wb") as f:
f.write(onx.SerializeToString())
# with open("debug.onnx", "wb") as f:
# f.write(onx.SerializeToString())
sess = InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
Expand Down

0 comments on commit e0799d3

Please sign in to comment.