Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jan 24, 2024
1 parent d5f47ea commit 27e36fe
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/test_sklearn_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def test_model_stacking_classifier_nozipmap_passthrough(self):

@unittest.skipIf(StackingClassifier is None, reason="new in 0.22")
@ignore_warnings(category=FutureWarning)
@unittest.skipIf(not skl12(), reason="sparse_output")
def test_issue_786_exc(self):
pipeline = make_pipeline(
OneHotEncoder(handle_unknown="ignore", sparse_output=False),
Expand Down
8 changes: 8 additions & 0 deletions tests/test_utils_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy
import pandas
from onnxruntime import __version__ as ort_version
from sklearn import __version__ as skl_version
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import load_iris
Expand Down Expand Up @@ -41,6 +42,12 @@
ort_version = ort_version.split("+")[0]


def skl12():
# pv.Version does not work with development versions
vers = ".".join(skl_version.split(".")[:2])
return pv.Version(vers) >= pv.Version("1.2")


class TestUtilsSklearn(unittest.TestCase):
@unittest.skipIf(VotingRegressor is None, reason="new in 0.21")
def test_voting_regression(self):
Expand Down Expand Up @@ -120,6 +127,7 @@ def test_pipeline_lr(self):
@unittest.skipIf(
pv.Version(ort_version) <= pv.Version("0.4.0"), reason="onnxruntime too old"
)
@unittest.skipIf(not skl12(), reason="sparse_output")
def test_pipeline_column_transformer(self):
iris = load_iris()
X = iris.data[:, :3]
Expand Down

0 comments on commit 27e36fe

Please sign in to comment.