Skip to content

Commit

Permalink
- Fixed OrdinalEncoder conversion. Now the type of keys for ONNX Labe…
Browse files Browse the repository at this point in the history
…lEncoder is determined by the type of input variable, and not by the type of categories in OrdinalEncoder

- Fixed FunctionTransformer converter. Added axis to Concat operator
- Fixed the Imputer shape calculator. The number of inputs can now be >1
- Fixed SklearnMultiply converter. Now initializer type is equal to input type
- Fixed Pipeline converter. A Cast operator is now applied when the pipeline output types differ from the last stage's output types
- Changed the VotingClassifier and VotingRegressor converters. VotingClassifier can now accept more than one input
  • Loading branch information
max-509 committed Nov 19, 2023
1 parent bc4e95c commit f3ec9ac
Show file tree
Hide file tree
Showing 14 changed files with 161 additions and 95 deletions.
10 changes: 9 additions & 1 deletion skl2onnx/common/shape_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def _calculate_linear_classifier_output_shapes(
],
)

_infer_linear_classifier_output_types(operator)


def _infer_linear_classifier_output_types(operator):
N = operator.inputs[0].get_first_dimension()
op = operator.raw_operator
class_labels = get_label_classes(operator.scope_inst, op)
Expand All @@ -78,7 +82,7 @@ def _calculate_linear_classifier_output_shapes(
shape = (
[len(op.classes_), N, max([len(x) for x in op.classes_])]
if isinstance(op.classes_, list)
and isinstance(op.classes_[0], np.ndarray)
and isinstance(op.classes_[0], np.ndarray)
else [N, number_of_classes]
)
operator.outputs[1].type.shape = shape
Expand Down Expand Up @@ -144,6 +148,10 @@ def _calculate_linear_regressor_output_shapes(operator, enable_type_checking=Tru
],
)

_infer_linear_regressor_output_types(operator)


def _infer_linear_regressor_output_types(operator):
inp0 = operator.inputs[0].type
if isinstance(inp0, (FloatTensorType, DoubleTensorType)):
cls_type = inp0.__class__
Expand Down
1 change: 1 addition & 0 deletions skl2onnx/operator_converters/function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def convert_sklearn_function_transformer(
[i.full_name for i in operator.inputs],
operator.outputs[0].full_name,
container,
axis=1,
)


Expand Down
27 changes: 14 additions & 13 deletions skl2onnx/operator_converters/multiply_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..proto import onnx_proto
from ..common.data_types import guess_proto_type


def convert_sklearn_multiply(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
operand_name = scope.get_unique_variable_name("operand")

container.add_initializer(
operand_name, onnx_proto.TensorProto.FLOAT, [], [operator.operand]
)

apply_mul(
scope,
[operator.inputs[0].full_name, operand_name],
operator.outputs[0].full_name,
container,
)
for input, output in zip(operator.inputs, operator.outputs):
operand_name = scope.get_unique_variable_name("operand")

container.add_initializer(
operand_name, guess_proto_type(input.type), [], [operator.operand]
)

apply_mul(
scope,
[input.full_name, operand_name],
output.full_name,
container,
)


register_converter("SklearnMultiply", convert_sklearn_multiply)
42 changes: 19 additions & 23 deletions skl2onnx/operator_converters/ordinal_encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np

from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
from ..common._container import ModelComponentContainer
from ..common.data_types import DoubleTensorType, Int32TensorType, \
Int16TensorType
from ..common.data_types import DoubleTensorType, FloatTensorType, \
Int64TensorType, Int32TensorType, Int16TensorType
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..proto import onnx_proto
Expand All @@ -25,21 +25,22 @@ def convert_sklearn_ordinal_encoder(

current_input = operator.inputs[input_idx]
if current_input.get_second_dimension() == 1:
feature_column_name = current_input.onnx_name
feature_column = current_input
input_idx += 1
else:
index_name = scope.get_unique_variable_name("index")
container.add_initializer(
index_name, onnx_proto.TensorProto.INT64, [], [dimension_idx]
)

feature_column_name = scope.get_unique_variable_name(
"feature_column")
feature_column = scope.declare_local_variable(
"feature_column", current_input.type.__class__([current_input.get_first_dimension(), 1])
)

container.add_node(
"ArrayFeatureExtractor",
[current_input.onnx_name, index_name],
feature_column_name,
feature_column.onnx_name,
op_domain="ai.onnx.ml",
name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
)
Expand All @@ -55,27 +56,22 @@ def convert_sklearn_ordinal_encoder(
if isinstance(current_input.type, (Int16TensorType, Int32TensorType)):
to = onnx_proto.TensorProto.INT64
if to is not None:
casted_feature_column_name = scope.get_unique_variable_name(
'casted_feature_column')
dtype = Int64TensorType if to == onnx_proto.TensorProto.INT64 else FloatTensorType
casted_feature_column = scope.declare_local_variable(
"casted_feature_column", dtype(copy.copy(feature_column.type.shape))
)

apply_cast(
scope, feature_column_name, casted_feature_column_name,
scope, feature_column.onnx_name, casted_feature_column.onnx_name,
container, to=to)

feature_column_name = casted_feature_column_name
feature_column = casted_feature_column

attrs = {"name": scope.get_unique_operator_name("LabelEncoder")}
if (
np.issubdtype(categories.dtype, np.floating)
or categories.dtype == np.bool_
or isinstance(categories[0], float)
):
attrs["keys_floats"] = categories
elif (
np.issubdtype(categories.dtype, np.signedinteger)
or isinstance(categories[0], int)
):
attrs["keys_int64s"] = categories
if isinstance(feature_column.type, FloatTensorType):
attrs["keys_floats"] = np.array([float(s) for s in categories], dtype=np.float32)
elif isinstance(feature_column.type, Int64TensorType):
attrs["keys_int64s"] = np.array([int(s) for s in categories], dtype=np.int64)
else:
attrs["keys_strings"] = np.array(
[str(s).encode("utf-8") for s in categories]
Expand All @@ -88,7 +84,7 @@ def convert_sklearn_ordinal_encoder(

container.add_node(
"LabelEncoder",
feature_column_name,
feature_column.onnx_name,
label_encoder_output,
op_domain="ai.onnx.ml",
op_version=2,
Expand Down
26 changes: 17 additions & 9 deletions skl2onnx/operator_converters/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import apply_cast
from ..common.data_types import guess_proto_type
from .._parse import _parse_sklearn


def convert_pipeline(
scope: Scope, operator: Operator, container: ModelComponentContainer
scope: Scope, operator: Operator, container: ModelComponentContainer
):
model = operator.raw_operator
inputs = operator.inputs
Expand All @@ -25,24 +27,30 @@ def convert_pipeline(
"last step outputs %d." % (len(outputs), len(operator.outputs))
)
for fr, to in zip(outputs, operator.outputs):
container.add_node(
"Identity",
fr.full_name,
to.full_name,
name=scope.get_unique_operator_name("Id" + operator.onnx_name),
)
if isinstance(to.type, type(fr.type)):
container.add_node(
"Identity",
fr.full_name,
to.full_name,
name=scope.get_unique_operator_name("Id" + operator.onnx_name),
)
else:
# The pipeline's declared output type differs from the last step's output type, so insert a Cast instead of an Identity.
apply_cast(scope, fr.full_name, to.full_name, container,
operator_name=scope.get_unique_operator_name("Cast" + operator.onnx_name),
to=guess_proto_type(to.type))


def convert_feature_union(
scope: Scope, operator: Operator, container: ModelComponentContainer
scope: Scope, operator: Operator, container: ModelComponentContainer
):
raise NotImplementedError(
"This converter not needed so far. It is usually handled " "during parsing."
)


def convert_column_transformer(
scope: Scope, operator: Operator, container: ModelComponentContainer
scope: Scope, operator: Operator, container: ModelComponentContainer
):
raise NotImplementedError(
"This converter not needed so far. It is usually handled " "during parsing."
Expand Down
4 changes: 2 additions & 2 deletions skl2onnx/operator_converters/voting_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def convert_voting_classifier(
operator.raw_operator.__class__.__name__
)
)
proto_dtype = guess_proto_type(operator.inputs[0].type)
proto_dtype = guess_proto_type(operator.outputs[1].type)
if proto_dtype != onnx_proto.TensorProto.DOUBLE:
proto_dtype = onnx_proto.TensorProto.FLOAT
op = operator.raw_operator
Expand All @@ -60,7 +60,7 @@ def convert_voting_classifier(

label_name = scope.declare_local_variable("label_%d" % i, Int64TensorType())
prob_name = scope.declare_local_variable(
"voting_proba_%d" % i, operator.inputs[0].type.__class__()
"voting_proba_%d" % i, operator.outputs[1].type.__class__()
)
this_operator.outputs.append(label_name)
this_operator.outputs.append(prob_name)
Expand Down
17 changes: 4 additions & 13 deletions skl2onnx/operator_converters/voting_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import apply_mul
from ..common.data_types import guess_proto_type, FloatTensorType, DoubleTensorType
from ..common.data_types import guess_proto_type
from .._supported_operators import sklearn_operator_name_map


Expand All @@ -16,15 +16,7 @@ def convert_voting_regressor(
Converts a *VotingRegressor* into *ONNX* format.
"""
op = operator.raw_operator

if not isinstance(operator.inputs[0].type, (FloatTensorType, DoubleTensorType)):
this_operator = scope.declare_local_operator("SklearnCast")
this_operator.inputs = operator.inputs
var_name = scope.declare_local_variable("cast", FloatTensorType())
this_operator.outputs.append(var_name)
inputs = this_operator.outputs
else:
inputs = operator.inputs
proto_dtype = guess_proto_type(operator.outputs[0].type)

vars_names = []
for i, estimator in enumerate(op.estimators_):
Expand All @@ -34,10 +26,10 @@ def convert_voting_regressor(
op_type = sklearn_operator_name_map[type(estimator)]

this_operator = scope.declare_local_operator(op_type, estimator)
this_operator.inputs = inputs
this_operator.inputs = operator.inputs

var_name = scope.declare_local_variable(
"var_%d" % i, inputs[0].type.__class__()
"var_%d" % i, operator.outputs[0].type.__class__()
)
this_operator.outputs.append(var_name)
var_name = var_name.onnx_name
Expand All @@ -48,7 +40,6 @@ def convert_voting_regressor(
val = 1.0 / len(op.estimators_)

weights_name = scope.get_unique_variable_name("w%d" % i)
proto_dtype = guess_proto_type(inputs[0].type)
container.add_initializer(weights_name, proto_dtype, [1], [val])
wvar_name = scope.get_unique_variable_name("wvar_%d" % i)
apply_mul(scope, [var_name, weights_name], wvar_name, container, broadcast=1)
Expand Down
1 change: 1 addition & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from . import local_outlier_factor
from . import mixture
from . import multioutput
from . import multiply
from . import nearest_neighbours
from . import one_hot_encoder
from . import ordinal_encoder
Expand Down
1 change: 0 additions & 1 deletion skl2onnx/shape_calculators/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def more_generic(t1, t2):

register_shape_calculator("SklearnConcat", calculate_sklearn_concat)
register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_concat)
register_shape_calculator("SklearnMultiply", calculate_sklearn_concat)
register_shape_calculator("SklearnRFE", calculate_sklearn_concat)
register_shape_calculator("SklearnRFECV", calculate_sklearn_concat)
register_shape_calculator("SklearnSelectFdr", calculate_sklearn_concat)
Expand Down
18 changes: 10 additions & 8 deletions skl2onnx/shape_calculators/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def calculate_sklearn_imputer_output_shapes(operator):
them along C-axis. The produced tensor's shape is used as the
output shape.
"""
check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
check_input_and_output_numbers(operator, input_count_range=[1, None], output_count_range=1)
check_input_and_output_types(
operator,
good_input_types=[
Expand All @@ -31,12 +31,14 @@ def calculate_sklearn_imputer_output_shapes(operator):
StringTensorType,
],
)
if not isinstance(operator.inputs[0].type, type(operator.outputs[0].type)): # noqa
raise RuntimeError(
"Inputs and outputs should have the same type "
"%r != %r."
% (type(operator.inputs[0].type), type(operator.outputs[0].type))
)
output = operator.outputs[0]
for variable in operator.inputs:
if not isinstance(variable.type, type(output.type)): # noqa
raise RuntimeError(
"Inputs and outputs should have the same type "
"%r != %r."
% (type(variable.type), type(output.type))
)

N = operator.inputs[0].get_first_dimension()
C = 0
Expand All @@ -47,7 +49,7 @@ def calculate_sklearn_imputer_output_shapes(operator):
C = None
break

operator.outputs[0].type.shape = [N, C]
output.type.shape = [N, C]


register_shape_calculator("SklearnImputer", calculate_sklearn_imputer_output_shapes)
Expand Down
12 changes: 12 additions & 0 deletions skl2onnx/shape_calculators/multiply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
import copy

from ..common._registration import register_shape_calculator


def calculate_sklearn_multiply(operator):
    """Give every output the same type as the input at the same position.

    Inputs and outputs are paired positionally with ``zip``, so any surplus
    entries on either side are left untouched.

    :param operator: object exposing ``inputs`` and ``outputs`` sequences
        whose items carry a ``type`` attribute; each paired output's
        ``type`` is overwritten in place.
    :return: None
    """
    for variable, output in zip(operator.inputs, operator.outputs):
        # Deep copy: a shallow copy would leave the output type sharing
        # mutable attributes (such as its shape list) with the input type,
        # so a later in-place edit of one would silently change the other.
        output.type = copy.deepcopy(variable.type)


register_shape_calculator("SklearnMultiply", calculate_sklearn_multiply)
9 changes: 5 additions & 4 deletions skl2onnx/shape_calculators/voting_classifier.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_shape_calculator
from ..common.shape_calculator import _calculate_linear_classifier_output_shapes
from ..common.utils import check_input_and_output_numbers
from ..common.shape_calculator import _infer_linear_classifier_output_types


def voting_classifier_shape_calculator(operator):
return _calculate_linear_classifier_output_shapes(
operator, enable_type_checking=False
)
check_input_and_output_numbers(operator, output_count_range=2)

_infer_linear_classifier_output_types(operator)


register_shape_calculator("SklearnVotingClassifier", voting_classifier_shape_calculator)
8 changes: 4 additions & 4 deletions skl2onnx/shape_calculators/voting_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@


from ..common._registration import register_shape_calculator
from ..common.shape_calculator import _calculate_linear_regressor_output_shapes
from ..common.utils import check_input_and_output_numbers
from ..common.shape_calculator import _infer_linear_regressor_output_types


def voting_regressor_shape_calculator(operator):
return _calculate_linear_regressor_output_shapes(
operator, enable_type_checking=False
)
check_input_and_output_numbers(operator, output_count_range=1)
return _infer_linear_regressor_output_types(operator)


register_shape_calculator("SklearnVotingRegressor", voting_regressor_shape_calculator)
Loading

0 comments on commit f3ec9ac

Please sign in to comment.