-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add converter for TunedThresholdClassifierCV (#1107)
* Add converter for TunedThresholdClassifierCV Signed-off-by: Xavier Dupre <[email protected]> * upgrade version Signed-off-by: Xavier Dupre <[email protected]> * documentation Signed-off-by: Xavier Dupre <[email protected]> * update numpy version Signed-off-by: Xavier Dupre <[email protected]> * do not use numpy 2 Signed-off-by: Xavier Dupre <[email protected]> * delay import Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]>
- Loading branch information
Showing
10 changed files
with
140 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
skl2onnx/operator_converters/tuned_threshold_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from ..common._registration import register_converter | ||
from ..common._topology import Scope, Operator | ||
from ..common._container import ModelComponentContainer | ||
from ..common.data_types import Int64TensorType | ||
from .._supported_operators import sklearn_operator_name_map | ||
|
||
|
||
def convert_sklearn_tuned_threshold_classifier( | ||
scope: Scope, operator: Operator, container: ModelComponentContainer | ||
): | ||
estimator = operator.raw_operator.estimator_ | ||
op_type = sklearn_operator_name_map[type(estimator)] | ||
|
||
this_operator = scope.declare_local_operator(op_type, estimator) | ||
this_operator.inputs = operator.inputs | ||
|
||
label_name = scope.declare_local_variable("label_tuned", Int64TensorType()) | ||
prob_name = scope.declare_local_variable( | ||
"proba_tuned", operator.outputs[1].type.__class__() | ||
) | ||
this_operator.outputs.append(label_name) | ||
this_operator.outputs.append(prob_name) | ||
|
||
container.add_node( | ||
"Identity", [label_name.onnx_name], [operator.outputs[0].full_name] | ||
) | ||
container.add_node( | ||
"Identity", [prob_name.onnx_name], [operator.outputs[1].full_name] | ||
) | ||
|
||
|
||
register_converter( | ||
"SklearnTunedThresholdClassifierCV", | ||
convert_sklearn_tuned_threshold_classifier, | ||
options={ | ||
"zipmap": [True, False, "columns"], | ||
"output_class_labels": [False, True], | ||
"nocl": [True, False], | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from ..common._registration import register_shape_calculator | ||
from ..common.utils import check_input_and_output_numbers | ||
from ..common.shape_calculator import _infer_linear_classifier_output_types | ||
|
||
|
||
def tuned_threshold_classifier_shape_calculator(operator): | ||
check_input_and_output_numbers(operator, output_count_range=2) | ||
|
||
_infer_linear_classifier_output_types(operator) | ||
|
||
|
||
register_shape_calculator( | ||
"SklearnTunedThresholdClassifierCV", tuned_threshold_classifier_shape_calculator | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import unittest | ||
import numpy as np | ||
from sklearn.datasets import make_classification | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.utils._testing import ignore_warnings | ||
from skl2onnx import to_onnx | ||
from skl2onnx.common.data_types import FloatTensorType | ||
from test_utils import dump_data_and_model, TARGET_OPSET | ||
|
||
|
||
def has_tuned_theshold_classifier(): | ||
try: | ||
from sklearn.model_selection import TunedThresholdClassifierCV # noqa: F401 | ||
except ImportError: | ||
return False | ||
return True | ||
|
||
|
||
class TestSklearnTunedThresholdClassifierConverter(unittest.TestCase): | ||
@unittest.skipIf( | ||
not has_tuned_theshold_classifier(), | ||
reason="TunedThresholdClassifierCV not available", | ||
) | ||
@ignore_warnings(category=FutureWarning) | ||
def test_tuned_threshold_classifier(self): | ||
from sklearn.model_selection import TunedThresholdClassifierCV | ||
|
||
X, y = make_classification( | ||
n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42 | ||
) | ||
X_train, X_test, y_train, y_test = train_test_split( | ||
X, y, stratify=y, random_state=42 | ||
) | ||
classifier = RandomForestClassifier(random_state=0) | ||
|
||
classifier_tuned = TunedThresholdClassifierCV( | ||
classifier, scoring="balanced_accuracy" | ||
).fit(X_train, y_train) | ||
|
||
model_onnx = to_onnx( | ||
classifier_tuned, | ||
initial_types=[("X", FloatTensorType([None, X_train.shape[1]]))], | ||
target_opset=TARGET_OPSET - 1, | ||
options={"zipmap": False}, | ||
) | ||
self.assertTrue(model_onnx is not None) | ||
dump_data_and_model( | ||
X_test[:10].astype(np.float32), | ||
classifier_tuned, | ||
model_onnx, | ||
basename="SklearnTunedThresholdClassifier", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=2) |