Skip to content

Commit

Permalink
fix many tiny bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jan 23, 2024
1 parent d0cbce2 commit 19d24b2
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
pandas.version: ''
lgbm.version: ''
onnxcc.version: '>=1.8.1'
run.example: '1'
run.example: '0'

Py311-Onnx140-Rt151-Skl130:
do.bench: '0'
Expand Down
14 changes: 6 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,12 @@
"sphinx.ext.graphviz",
"sphinx_skl2onnx_extension",
"matplotlib.sphinxext.plot_directive",
"pyquickhelper.sphinxext.sphinx_cmdref_extension",
"pyquickhelper.sphinxext.sphinx_collapse_extension",
"pyquickhelper.sphinxext.sphinx_docassert_extension",
"pyquickhelper.sphinxext.sphinx_epkg_extension",
"pyquickhelper.sphinxext.sphinx_exref_extension",
"pyquickhelper.sphinxext.sphinx_faqref_extension",
"pyquickhelper.sphinxext.sphinx_gdot_extension",
"pyquickhelper.sphinxext.sphinx_runpython_extension",
"sphinx_runpython.blocdefs.sphinx_exref_extension",
"sphinx_runpython.blocdefs.sphinx_faqref_extension",
"sphinx_runpython.blocdefs.sphinx_mathdef_extension",
"sphinx_runpython.epkg",
"sphinx_runpython.gdot",
"sphinx_runpython.runpython",
"sphinxcontrib.blockdiag",
]

Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ pandas
pydot
pyinstrument
pyod
pyquickhelper>=1.11.3762
pytest
pytest-cov
scikit-learn>=1.1
skl2onnx
sphinx
sphinxcontrib-blockdiag
sphinx-gallery
sphinx-runpython
tabulate
tqdm
wheel
Expand Down
2 changes: 2 additions & 0 deletions docs/tests/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def add_test_methods(cls):
if not name.endswith(".py") or not name.startswith("plot_"):
continue
reason = None
if name in {"plot_woe_transformer.py"}:
reason = "dot not available"

if reason:

Expand Down
29 changes: 0 additions & 29 deletions docs/tutorial/plot_gbegin_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
+++++++++++++++++++++++++
"""
from mlinsights.plotting import pipeline2dot
import numpy
import pprint
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from pandas import DataFrame
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestClassifier
from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
from skl2onnx import to_onnx
from skl2onnx.algebra.type_helper import guess_initial_types

Expand Down Expand Up @@ -54,14 +51,6 @@
pipe = Pipeline([("preprocess", preprocessor), ("rf", RandomForestClassifier())])
pipe.fit(train_data, data["y"])

#####################################
# Display.

dot = pipeline2dot(pipe, train_data)
ax = plot_graphviz(dot)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

#######################################
# Conversion to ONNX
# ++++++++++++++++++
Expand All @@ -84,24 +73,6 @@
except Exception as e:
print(e)

###########################
# Let's use a shortcut

oinf = ReferenceEvaluator(onx)
got = oinf.run(None, train_data)
print(pipe.predict(train_data))
print(got["label"])

#################################
# And probabilities.

print(pipe.predict_proba(train_data))
print(got["probabilities"])

######################################
# It looks ok. Let's dig into the details to
# directly use *onnxruntime*.
#
# Unhide conversion logic with a dataframe
# ++++++++++++++++++++++++++++++++++++++++
#
Expand Down
2 changes: 1 addition & 1 deletion skl2onnx/algebra/type_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _guess_dim(value):
)


def guess_initial_types(X, initial_types):
def guess_initial_types(X, initial_types=None):
if X is None and initial_types is None:
raise NotImplementedError("Initial types must be specified.")
if initial_types is None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def one_hot_encoder_supports_string():
return pv.Version(vers) >= pv.Version("0.20.0")


def skl12():
    """Tell whether the installed scikit-learn is at least version 1.2.

    Only the first two dot-separated components of ``sklearn_version``
    are compared, because ``pv.Version`` does not work with development
    versions.
    """
    major_minor = sklearn_version.split(".")[:2]
    installed = pv.Version(".".join(major_minor))
    minimum = pv.Version("1.2")
    return installed >= minimum


class TestOnnxHelper(unittest.TestCase):
def get_model(self, model):
try:
Expand Down Expand Up @@ -73,6 +79,7 @@ def test_onnx_helper_load_save(self):
not one_hot_encoder_supports_string(),
reason="OneHotEncoder did not have categories_ before 0.20",
)
@unittest.skipIf(not skl12(), reason="sparse_output")
def test_onnx_helper_load_save_init(self):
model = make_pipeline(
Binarizer(),
Expand Down Expand Up @@ -105,6 +112,7 @@ def test_onnx_helper_load_save_init(self):
not one_hot_encoder_supports_string(),
reason="OneHotEncoder did not have categories_ before 0.20",
)
@unittest.skipIf(not skl12(), reason="sparse_output")
def test_onnx_helper_load_save_init_meta(self):
model = make_pipeline(
Binarizer(), OneHotEncoder(sparse_output=False), StandardScaler()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_sklearn_array_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import numpy as np
from onnxruntime import __version__ as ort_version
from sklearn import __version__ as sklearn_version
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import OneHotEncoder

Expand All @@ -20,11 +21,18 @@
from test_utils import dump_data_and_model, TARGET_OPSET


def skl12():
    """Tell whether the installed scikit-learn is at least version 1.2.

    Only the first two dot-separated components of ``sklearn_version``
    are compared, because ``pv.Version`` does not work with development
    versions.
    """
    major_minor = sklearn_version.split(".")[:2]
    installed = pv.Version(".".join(major_minor))
    minimum = pv.Version("1.2")
    return installed >= minimum


class TestSklearnArrayFeatureExtractor(unittest.TestCase):
@unittest.skipIf(
ColumnTransformer is None or pv.Version(ort_version) <= pv.Version("0.4.0"),
reason="onnxruntime too old",
)
@unittest.skipIf(not skl12(), reason="sparse_output")
def test_array_feature_extractor(self):
data_to_cluster = pd.DataFrame(
[[1, 2, 3.5, 4.5], [1, 2, 1.7, 4.0], [2, 4, 2.4, 4.3], [2, 4, 2.5, 4.0]],
Expand Down
14 changes: 14 additions & 0 deletions tests/test_sklearn_calibrated_classifier_cv_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
from numpy.testing import assert_almost_equal
from onnxruntime import __version__ as ort_version
from sklearn import __version__ as sklearn_version
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_digits, load_iris
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
Expand Down Expand Up @@ -49,6 +50,12 @@
ort_version = ort_version.split("+")[0]


def skl12():
    """Tell whether the installed scikit-learn is at least version 1.2.

    Only the first two dot-separated components of ``sklearn_version``
    are compared, because ``pv.Version`` does not work with development
    versions.
    """
    major_minor = sklearn_version.split(".")[:2]
    installed = pv.Version(".".join(major_minor))
    minimum = pv.Version("1.2")
    return installed >= minimum


class TestSklearnCalibratedClassifierCVConverters(unittest.TestCase):
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
def test_model_calibrated_classifier_cv_float(self):
Expand Down Expand Up @@ -186,6 +193,7 @@ def test_model_calibrated_classifier_cv_isotonic_binary_knn(self):
pv.Version(ort_version) < pv.Version("0.5.0"), reason="not available"
)
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_logistic_regression(self):
data = load_iris()
X, y = data.data, data.target
Expand All @@ -210,6 +218,7 @@ def test_model_calibrated_classifier_cv_logistic_regression(self):
pv.Version(ort_version) < pv.Version("0.5.0"), reason="not available"
)
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_rf(self):
data = load_iris()
X, y = data.data, data.target
Expand All @@ -234,6 +243,7 @@ def test_model_calibrated_classifier_cv_rf(self):
pv.Version(ort_version) < pv.Version("0.5.0"), reason="not available"
)
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_gbt(self):
data = load_iris()
X, y = data.data, data.target
Expand All @@ -259,6 +269,7 @@ def test_model_calibrated_classifier_cv_gbt(self):
pv.Version(ort_version) < pv.Version("0.5.0"), reason="not available"
)
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_hgbt(self):
data = load_iris()
X, y = data.data, data.target
Expand Down Expand Up @@ -308,6 +319,7 @@ def test_model_calibrated_classifier_cv_tree(self):
)
@unittest.skipIf(apply_less is None, reason="onnxconverter-common old")
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_svc(self):
data = load_iris()
X, y = data.data, data.target
Expand All @@ -330,6 +342,7 @@ def test_model_calibrated_classifier_cv_svc(self):
)
@unittest.skipIf(apply_less is None, reason="onnxconverter-common old")
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_linearsvc(self):
data = load_iris()
X, y = data.data, data.target
Expand All @@ -354,6 +367,7 @@ def test_model_calibrated_classifier_cv_linearsvc(self):
)
@unittest.skipIf(apply_less is None, reason="onnxconverter-common old")
@ignore_warnings(category=(FutureWarning, ConvergenceWarning, DeprecationWarning))
@unittest.skipIf(not skl12(), reason="base_estimator")
def test_model_calibrated_classifier_cv_linearsvc2(self):
data = load_iris()
X, y = data.data, data.target
Expand Down

0 comments on commit 19d24b2

Please sign in to comment.