From 03db70c8ea2468b51522fdc9957d903f8d316c08 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 22 May 2024 12:35:59 +0200 Subject: [PATCH] fix disc Signed-off-by: Xavier Dupre --- skl2onnx/operator_converters/linear_classifier.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/skl2onnx/operator_converters/linear_classifier.py b/skl2onnx/operator_converters/linear_classifier.py index d1e20ba3f..300ee7837 100644 --- a/skl2onnx/operator_converters/linear_classifier.py +++ b/skl2onnx/operator_converters/linear_classifier.py @@ -94,13 +94,24 @@ def convert_sklearn_linear_classifier( classifier_attrs["post_transform"] = "NONE" elif isinstance(op, LogisticRegression): classifier_attrs["post_transform"] = ( - "LOGISTIC" if (use_ovr or multi_class == 0) else "SOFTMAX" + "LOGISTIC" + if (use_ovr or (multi_class == 0 and op.intercept_.size <= 1)) + else "SOFTMAX" ) else: classifier_attrs["post_transform"] = ( "LOGISTIC" if multi_class > 2 else "SOFTMAX" ) + print( + "***", + multi_class, + classifier_attrs["post_transform"], + type(op), + number_of_classes, + op.__dict__, + ) + if all(isinstance(i, str) for i in classes): class_labels = [str(i) for i in classes] classifier_attrs["classlabels_strings"] = class_labels