From 922ab4d27dfb3f5c236bb8e3027febe2f6e1bcef Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 8 Feb 2024 14:47:56 +0100 Subject: [PATCH 1/2] Fix unexpected type for intercept in PoissonRegressor and GammaRegressor Signed-off-by: Xavier Dupre --- pyproject.toml | 4 +-- .../operator_converters/gamma_regressor.py | 2 +- .../operator_converters/linear_regressor.py | 2 +- tests/test_sklearn_gamma_regressor.py | 34 ++++++++++++++++++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c04c85d1c..86f0d9b0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ exclude = [ # Same as Black. line-length = 88 -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "skl2onnx/algebra/onnx_ops.py" = ["F821"] diff --git a/skl2onnx/operator_converters/gamma_regressor.py b/skl2onnx/operator_converters/gamma_regressor.py index 030eaa47b..1676fa37f 100644 --- a/skl2onnx/operator_converters/gamma_regressor.py +++ b/skl2onnx/operator_converters/gamma_regressor.py @@ -34,7 +34,7 @@ def convert_sklearn_gamma_regressor( intercept = ( op.intercept_.astype(dtype) - if len(op.intercept_.shape) > 0 + if isinstance(op.intercept_, np.ndarray) and len(op.intercept_.shape) > 0 else np.array([op.intercept_], dtype=dtype) ) eta = OnnxAdd( diff --git a/skl2onnx/operator_converters/linear_regressor.py b/skl2onnx/operator_converters/linear_regressor.py index 06055bfc2..3c99c99f8 100644 --- a/skl2onnx/operator_converters/linear_regressor.py +++ b/skl2onnx/operator_converters/linear_regressor.py @@ -198,7 +198,7 @@ def convert_sklearn_poisson_regressor( intercept = ( op.intercept_.astype(dtype) - if len(op.intercept_.shape) > 0 + if isinstance(op.intercept_, np.ndarray) and len(op.intercept_.shape) > 0 else np.array([op.intercept_], dtype=dtype) ) eta = OnnxAdd( diff --git a/tests/test_sklearn_gamma_regressor.py b/tests/test_sklearn_gamma_regressor.py index e09a3cc48..4806b73ba 100644 --- a/tests/test_sklearn_gamma_regressor.py +++ b/tests/test_sklearn_gamma_regressor.py @@ -6,7 +6,7 @@ import numpy as np try: - from sklearn.linear_model import GammaRegressor + from sklearn.linear_model import GammaRegressor, PoissonRegressor except ImportError: GammaRegressor = None from onnxruntime import __version__ as ort_version @@ -90,6 +90,38 @@ def test_gamma_regressor_double(self): basename="SklearnGammaRegressor", ) + @unittest.skipIf(GammaRegressor is None, reason="scikit-learn<1.0") + def test_poisson_without_intercept(self): + # Poisson + model = PoissonRegressor(fit_intercept=False) + X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 3.0]]) + y = np.array([19.0, 26.0, 33.0, 30.0]) + model.fit(X, y) + + model_onnx = convert_sklearn( + model, + "scikit-learn Poisson Regressor without Intercept", + [("input", FloatTensorType([None, X.shape[1]]))], + ) + + self.assertIsNotNone(model_onnx is not None) + + @unittest.skipIf(GammaRegressor is None, reason="scikit-learn<1.0") + def test_gamma_without_intercept(self): + # Gamma + model = GammaRegressor(fit_intercept=False) + X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 3.0]]) + y = np.array([19.0, 26.0, 33.0, 30.0]) + model.fit(X, y) + + model_onnx = convert_sklearn( + model, + "scikit-learn Gamma Regressor without Intercept", + [("input", FloatTensorType([None, X.shape[1]]))], + ) + + self.assertIsNotNone(model_onnx is not None) + if __name__ == "__main__": unittest.main(verbosity=3) From 242aef54d6e2873ae75e2f9863206764b9188a0d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 13 Feb 2024 11:25:08 +0100 Subject: [PATCH 2/2] fix changelogs Signed-off-by: Xavier Dupre --- CHANGELOGS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOGS.md b/CHANGELOGS.md index c4f374b85..ed5de5ed5 100644 --- a/CHANGELOGS.md +++ b/CHANGELOGS.md @@ -2,6 +2,8 @@ ## 1.17.0 (development) +* Fix unexpected type for intercept in PoissonRegressor and GammaRegressor + [#1070](https://github.com/onnx/sklearn-onnx/pull/1070) * Add support for scikti-learn 1.4.0, [#1058](https://github.com/onnx/sklearn-onnx/pull/1058), fixes issues [Many examples in the gallery are showing "broken"](https://github.com/onnx/sklearn-onnx/pull/1057),