From 27e36fefdb7eacfb157c0225aac22dbce393109b Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 23 Jan 2024 18:36:46 +0100 Subject: [PATCH] fix lint Signed-off-by: Xavier Dupre --- tests/test_sklearn_stacking.py | 1 + tests/test_utils_sklearn.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/tests/test_sklearn_stacking.py b/tests/test_sklearn_stacking.py index 5f5aae450..9b701b97f 100644 --- a/tests/test_sklearn_stacking.py +++ b/tests/test_sklearn_stacking.py @@ -192,6 +192,7 @@ def test_model_stacking_classifier_nozipmap_passthrough(self): @unittest.skipIf(StackingClassifier is None, reason="new in 0.22") @ignore_warnings(category=FutureWarning) + @unittest.skipIf(not skl12(), reason="sparse_output") def test_issue_786_exc(self): pipeline = make_pipeline( OneHotEncoder(handle_unknown="ignore", sparse_output=False), diff --git a/tests/test_utils_sklearn.py b/tests/test_utils_sklearn.py index 6beb7c7f5..c6c1d3847 100644 --- a/tests/test_utils_sklearn.py +++ b/tests/test_utils_sklearn.py @@ -8,6 +8,7 @@ import numpy import pandas from onnxruntime import __version__ as ort_version +from sklearn import __version__ as skl_version from sklearn.linear_model import LinearRegression from sklearn.ensemble import RandomForestRegressor from sklearn.datasets import load_iris @@ -41,6 +42,12 @@ ort_version = ort_version.split("+")[0] +def skl12(): + # pv.Version does not work with development versions + vers = ".".join(skl_version.split(".")[:2]) + return pv.Version(vers) >= pv.Version("1.2") + + class TestUtilsSklearn(unittest.TestCase): @unittest.skipIf(VotingRegressor is None, reason="new in 0.21") def test_voting_regression(self): @@ -120,6 +127,7 @@ def test_pipeline_lr(self): @unittest.skipIf( pv.Version(ort_version) <= pv.Version("0.4.0"), reason="onnxruntime too old" ) + @unittest.skipIf(not skl12(), reason="sparse_output") def test_pipeline_column_transformer(self): iris = load_iris() X = iris.data[:, :3]