Skip to content

Commit

Permalink
OrdinalEncoder conversion fixes (#1044)
Browse files Browse the repository at this point in the history
* OrdinalEncoder conversion fixes

* Test fixes; added conversions for OrdinalEncoder input variables

* - Fixed OrdinalEncoder conversion. Now the type of keys for ONNX LabelEncoder is determined by the type of input variable, and not by the type of categories in OrdinalEncoder
- Fixed FunctionTransformer converter. Added axis to Concat operator
- Fixed Imputer shapes calculator. Now number of inputs can be >1
- Fixed SklearnMultiply converter. Now initializer type is equal to input type
- Fixed Pipeline converter. Now Cast operator applies if pipeline outputs are different with last stage outputs
- Changed VotingClassifier and VotingRegressor converter. Now VotingClassifier can accept number of inputs >1

* CI fix

* Run black formatter

* Changed assertTrue to assertIsNotNone

---------

Co-authored-by: Vershinin Maxim WX1123714 <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>
  • Loading branch information
3 people authored Dec 7, 2023
1 parent 943cac2 commit 78933db
Show file tree
Hide file tree
Showing 16 changed files with 282 additions and 143 deletions.
6 changes: 4 additions & 2 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class OutlierMixin:
from sklearn.neighbors import NearestNeighbors, LocalOutlierFactor
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC, NuSVC, SVC

Expand Down Expand Up @@ -65,7 +65,9 @@ class OutlierMixin:


do_not_merge_columns = tuple(
filter(lambda op: op is not None, [OneHotEncoder, ColumnTransformer])
filter(
lambda op: op is not None, [OrdinalEncoder, OneHotEncoder, ColumnTransformer]
)
)


Expand Down
8 changes: 8 additions & 0 deletions skl2onnx/common/shape_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def _calculate_linear_classifier_output_shapes(
],
)

_infer_linear_classifier_output_types(operator)


def _infer_linear_classifier_output_types(operator):
N = operator.inputs[0].get_first_dimension()
op = operator.raw_operator
class_labels = get_label_classes(operator.scope_inst, op)
Expand Down Expand Up @@ -144,6 +148,10 @@ def _calculate_linear_regressor_output_shapes(operator, enable_type_checking=Tru
],
)

_infer_linear_regressor_output_types(operator)


def _infer_linear_regressor_output_types(operator):
inp0 = operator.inputs[0].type
if isinstance(inp0, (FloatTensorType, DoubleTensorType)):
cls_type = inp0.__class__
Expand Down
1 change: 1 addition & 0 deletions skl2onnx/operator_converters/function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def convert_sklearn_function_transformer(
[i.full_name for i in operator.inputs],
operator.outputs[0].full_name,
container,
axis=1,
)


Expand Down
27 changes: 14 additions & 13 deletions skl2onnx/operator_converters/multiply_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..proto import onnx_proto
from ..common.data_types import guess_proto_type


def convert_sklearn_multiply(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
operand_name = scope.get_unique_variable_name("operand")

container.add_initializer(
operand_name, onnx_proto.TensorProto.FLOAT, [], [operator.operand]
)

apply_mul(
scope,
[operator.inputs[0].full_name, operand_name],
operator.outputs[0].full_name,
container,
)
for input, output in zip(operator.inputs, operator.outputs):
operand_name = scope.get_unique_variable_name("operand")

container.add_initializer(
operand_name, guess_proto_type(input.type), [], [operator.operand]
)

apply_mul(
scope,
[input.full_name, operand_name],
output.full_name,
container,
)


register_converter("SklearnMultiply", convert_sklearn_multiply)
161 changes: 88 additions & 73 deletions skl2onnx/operator_converters/ordinal_encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np

from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
from ..common.data_types import Int64TensorType, StringTensorType
from ..common._container import ModelComponentContainer
from ..common.data_types import (
DoubleTensorType,
FloatTensorType,
Int64TensorType,
Int32TensorType,
Int16TensorType,
)
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..proto import onnx_proto


Expand All @@ -15,92 +22,100 @@ def convert_sklearn_ordinal_encoder(
):
ordinal_op = operator.raw_operator
result = []
concatenated_input_name = operator.inputs[0].full_name
concat_result_name = scope.get_unique_variable_name("concat_result")
input_idx = 0
dimension_idx = 0
for categories in ordinal_op.categories_:
if len(categories) == 0:
continue

if len(operator.inputs) > 1:
concatenated_input_name = scope.get_unique_variable_name("concatenated_input")
if all(
isinstance(inp.type, type(operator.inputs[0].type))
for inp in operator.inputs
):
input_names = list(map(lambda x: x.full_name, operator.inputs))
current_input = operator.inputs[input_idx]
if current_input.get_second_dimension() == 1:
feature_column = current_input
input_idx += 1
else:
input_names = []
for inp in operator.inputs:
if isinstance(inp.type, Int64TensorType):
input_names.append(scope.get_unique_variable_name("cast_input"))
apply_cast(
scope,
inp.full_name,
input_names[-1],
container,
to=onnx_proto.TensorProto.STRING,
)
elif isinstance(inp.type, StringTensorType):
input_names.append(inp.full_name)
else:
raise NotImplementedError(
"{} input datatype not yet supported. "
"You may raise an issue at "
"https://github.com/onnx/sklearn-onnx/issues"
"".format(type(inp.type))
)

apply_concat(scope, input_names, concatenated_input_name, container, axis=1)
if len(ordinal_op.categories_) == 0:
raise RuntimeError(
"No categories found in type=%r, encoder=%r."
% (type(ordinal_op), ordinal_op)
)
for index, categories in enumerate(ordinal_op.categories_):
attrs = {"name": scope.get_unique_operator_name("LabelEncoder")}
if len(categories) > 0:
if (
np.issubdtype(categories.dtype, np.floating)
or categories.dtype == np.bool_
):
attrs["keys_floats"] = categories
elif np.issubdtype(categories.dtype, np.signedinteger):
attrs["keys_int64s"] = categories
else:
attrs["keys_strings"] = np.array(
[str(s).encode("utf-8") for s in categories]
)
attrs["values_int64s"] = np.arange(len(categories)).astype(np.int64)

index_name = scope.get_unique_variable_name("index")
feature_column_name = scope.get_unique_variable_name("feature_column")
result.append(scope.get_unique_variable_name("ordinal_output"))
label_encoder_output = scope.get_unique_variable_name("label_encoder")

container.add_initializer(
index_name, onnx_proto.TensorProto.INT64, [], [index]
index_name, onnx_proto.TensorProto.INT64, [], [dimension_idx]
)

feature_column = scope.declare_local_variable(
"feature_column",
current_input.type.__class__([current_input.get_first_dimension(), 1]),
)

container.add_node(
"ArrayFeatureExtractor",
[concatenated_input_name, index_name],
feature_column_name,
[current_input.onnx_name, index_name],
feature_column.onnx_name,
op_domain="ai.onnx.ml",
name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
)

container.add_node(
"LabelEncoder",
feature_column_name,
label_encoder_output,
op_domain="ai.onnx.ml",
op_version=2,
**attrs
dimension_idx += 1
if dimension_idx == current_input.get_second_dimension():
dimension_idx = 0
input_idx += 1

to = None
if isinstance(current_input.type, DoubleTensorType):
to = onnx_proto.TensorProto.FLOAT
if isinstance(current_input.type, (Int16TensorType, Int32TensorType)):
to = onnx_proto.TensorProto.INT64
if to is not None:
dtype = (
Int64TensorType
if to == onnx_proto.TensorProto.INT64
else FloatTensorType
)
casted_feature_column = scope.declare_local_variable(
"casted_feature_column", dtype(copy.copy(feature_column.type.shape))
)
apply_reshape(

apply_cast(
scope,
label_encoder_output,
result[-1],
feature_column.onnx_name,
casted_feature_column.onnx_name,
container,
desired_shape=(-1, 1),
to=to,
)

feature_column = casted_feature_column

attrs = {"name": scope.get_unique_operator_name("LabelEncoder")}
if isinstance(feature_column.type, FloatTensorType):
attrs["keys_floats"] = np.array(
[float(s) for s in categories], dtype=np.float32
)
elif isinstance(feature_column.type, Int64TensorType):
attrs["keys_int64s"] = np.array(
[int(s) for s in categories], dtype=np.int64
)
else:
attrs["keys_strings"] = np.array(
[str(s).encode("utf-8") for s in categories]
)
attrs["values_int64s"] = np.arange(len(categories)).astype(np.int64)

result.append(scope.get_unique_variable_name("ordinal_output"))
label_encoder_output = scope.get_unique_variable_name("label_encoder")

container.add_node(
"LabelEncoder",
feature_column.onnx_name,
label_encoder_output,
op_domain="ai.onnx.ml",
op_version=2,
**attrs
)
apply_reshape(
scope,
label_encoder_output,
result[-1],
container,
desired_shape=(-1, 1),
)

concat_result_name = scope.get_unique_variable_name("concat_result")
apply_concat(scope, result, concat_result_name, container, axis=1)
cast_type = (
onnx_proto.TensorProto.FLOAT
Expand Down
27 changes: 21 additions & 6 deletions skl2onnx/operator_converters/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import apply_cast
from ..common.data_types import guess_proto_type
from .._parse import _parse_sklearn


Expand All @@ -25,12 +27,25 @@ def convert_pipeline(
"last step outputs %d." % (len(outputs), len(operator.outputs))
)
for fr, to in zip(outputs, operator.outputs):
container.add_node(
"Identity",
fr.full_name,
to.full_name,
name=scope.get_unique_operator_name("Id" + operator.onnx_name),
)
if isinstance(to.type, type(fr.type)):
container.add_node(
"Identity",
fr.full_name,
to.full_name,
name=scope.get_unique_operator_name("Id" + operator.onnx_name),
)
else:
# If Pipeline output types are different with last stage output type
apply_cast(
scope,
fr.full_name,
to.full_name,
container,
operator_name=scope.get_unique_operator_name(
"Cast" + operator.onnx_name
),
to=guess_proto_type(to.type),
)


def convert_feature_union(
Expand Down
4 changes: 2 additions & 2 deletions skl2onnx/operator_converters/voting_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def convert_voting_classifier(
operator.raw_operator.__class__.__name__
)
)
proto_dtype = guess_proto_type(operator.inputs[0].type)
proto_dtype = guess_proto_type(operator.outputs[1].type)
if proto_dtype != onnx_proto.TensorProto.DOUBLE:
proto_dtype = onnx_proto.TensorProto.FLOAT
op = operator.raw_operator
Expand All @@ -60,7 +60,7 @@ def convert_voting_classifier(

label_name = scope.declare_local_variable("label_%d" % i, Int64TensorType())
prob_name = scope.declare_local_variable(
"voting_proba_%d" % i, operator.inputs[0].type.__class__()
"voting_proba_%d" % i, operator.outputs[1].type.__class__()
)
this_operator.outputs.append(label_name)
this_operator.outputs.append(prob_name)
Expand Down
17 changes: 4 additions & 13 deletions skl2onnx/operator_converters/voting_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import apply_mul
from ..common.data_types import guess_proto_type, FloatTensorType, DoubleTensorType
from ..common.data_types import guess_proto_type
from .._supported_operators import sklearn_operator_name_map


Expand All @@ -16,15 +16,7 @@ def convert_voting_regressor(
Converts a *VotingRegressor* into *ONNX* format.
"""
op = operator.raw_operator

if not isinstance(operator.inputs[0].type, (FloatTensorType, DoubleTensorType)):
this_operator = scope.declare_local_operator("SklearnCast")
this_operator.inputs = operator.inputs
var_name = scope.declare_local_variable("cast", FloatTensorType())
this_operator.outputs.append(var_name)
inputs = this_operator.outputs
else:
inputs = operator.inputs
proto_dtype = guess_proto_type(operator.outputs[0].type)

vars_names = []
for i, estimator in enumerate(op.estimators_):
Expand All @@ -34,10 +26,10 @@ def convert_voting_regressor(
op_type = sklearn_operator_name_map[type(estimator)]

this_operator = scope.declare_local_operator(op_type, estimator)
this_operator.inputs = inputs
this_operator.inputs = operator.inputs

var_name = scope.declare_local_variable(
"var_%d" % i, inputs[0].type.__class__()
"var_%d" % i, operator.outputs[0].type.__class__()
)
this_operator.outputs.append(var_name)
var_name = var_name.onnx_name
Expand All @@ -48,7 +40,6 @@ def convert_voting_regressor(
val = 1.0 / len(op.estimators_)

weights_name = scope.get_unique_variable_name("w%d" % i)
proto_dtype = guess_proto_type(inputs[0].type)
container.add_initializer(weights_name, proto_dtype, [1], [val])
wvar_name = scope.get_unique_variable_name("wvar_%d" % i)
apply_mul(scope, [var_name, weights_name], wvar_name, container, broadcast=1)
Expand Down
Loading

0 comments on commit 78933db

Please sign in to comment.