Skip to content

Commit

Permalink
Check scikit-learn==1.5.0
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed May 22, 2024
2 parents be3e3da + 49473d6 commit 293391d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/linux-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ jobs:
os: [ubuntu-latest]
python_version: ['3.12', '3.11', '3.10', '3.9']
include:
- sklearn_version: '==1.5.0'
documentation: 0
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
python_version: '3.12'
- python_version: '3.12'
documentation: 0
numpy_version: '>=1.21.1'
Expand Down
14 changes: 13 additions & 1 deletion .github/workflows/windows-macos-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ jobs:
os: [windows-latest, macos-latest]
python_version: ['3.11', '3.10', '3.9']
include:
- sklearn_version: '==1.5.0'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
- python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.16.0'
onnxrt_version: 'onnxruntime==1.17.3'
onnxrt_version: 'onnxruntime<1.18.0'
sklearn_version: '==1.3.2'
- python_version: '3.10'
numpy_version: '>=1.21.1'
Expand All @@ -27,6 +33,12 @@ jobs:
onnx_version: 'onnx<1.14'
onnxrt_version: 'onnxruntime<1.16.0'
sklearn_version: '==1.2.2'
- sklearn_version: '==1.4.2'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'

steps:
- name: Checkout repository
Expand Down
24 changes: 19 additions & 5 deletions skl2onnx/operator_converters/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,26 @@ def convert_sklearn_linear_classifier(
intercepts = list(map(lambda x: -1 * x, intercepts)) + intercepts

multi_class = 0
use_ovr = False
if hasattr(op, "multi_class"):
if op.multi_class == "ovr":
multi_class = 1
else:
elif number_of_classes > 2:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
multi_class = 2
use_ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
)
else:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
if number_of_classes > 2:
multi_class = 2

classifier_type = "LinearClassifier"
Expand All @@ -77,11 +93,9 @@ def convert_sklearn_linear_classifier(
):
classifier_attrs["post_transform"] = "NONE"
elif isinstance(op, LogisticRegression):
ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
classifier_attrs["post_transform"] = (
"LOGISTIC" if (use_ovr or multi_class == 0) else "SOFTMAX"
)
classifier_attrs["post_transform"] = "LOGISTIC" if ovr else "SOFTMAX"
else:
classifier_attrs["post_transform"] = (
"LOGISTIC" if multi_class > 2 else "SOFTMAX"
Expand Down

0 comments on commit 293391d

Please sign in to comment.