Skip to content

Commit

Permalink
fix many tiny bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jan 23, 2024
1 parent 6ef3261 commit 4c8ab22
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/tests/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
def add_test_methods(cls):
this = os.path.abspath(os.path.dirname(__file__))
folds = [
os.path.normpath(os.path.join(this, "..", "docs", "examples")),
os.path.normpath(os.path.join(this, "..", "docs", "tutorial")),
os.path.normpath(os.path.join(this, "..", "examples")),
os.path.normpath(os.path.join(this, "..", "tutorial")),
]
for fold in folds:
found = os.listdir(fold)
Expand Down
16 changes: 10 additions & 6 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,18 @@ def _parse_sklearn_grid_search_cv(scope, model, inputs, custom_parsers=None):


def _parse_sklearn_random_trees_embedding(scope, model, inputs, custom_parsers=None):
res = parse_sklearn(
scope, model.base_estimator_, inputs, custom_parsers=custom_parsers
)
if hasattr(model, "estimator_"):
est = model.estimator_
elif hasattr(model, "base_estimator_"):
est = model.base_estimator_
else:
raise RuntimeError(
f"Model {model} was not trained (unable to find the estimator {dir(model)})."
)
res = parse_sklearn(scope, est, inputs, custom_parsers=custom_parsers)
if len(res) != 1:
raise RuntimeError("A regressor only produces one output not %r." % res)
scope.replace_raw_operator(
model.base_estimator_, model, "SklearnRandomTreesEmbedding"
)
scope.replace_raw_operator(est, model, "SklearnRandomTreesEmbedding")
return res


Expand Down
6 changes: 6 additions & 0 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import unittest
import packaging.version as pv
from onnxruntime import __version__ as ort_version


class TestInvestigate(unittest.TestCase):
Expand Down Expand Up @@ -39,6 +41,10 @@ def test_issue_1053(self):
] # Select a single sample.
self.assertEqual(len(pred_onx.tolist()), 1)

@unittest.skipIf(
pv.Version(ort_version) < pv.Version("1.16.0"),
reason="opset 19 not implemented",
)
def test_issue_1055(self):
import numpy as np
from numpy.testing import assert_almost_equal
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_double_tensor_type_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_calibration_sigmoid_64(self):
self._common_classifier(
[
lambda: CalibratedClassifierCV(
base_estimator=LogisticRegression(), method="sigmoid"
estimator=LogisticRegression(), method="sigmoid"
)
],
"CalibratedClassifierCV",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_glm_classifier_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test_model_logistic_regression_multi_class_no_intercept(self):

@ignore_warnings(category=(DeprecationWarning, ConvergenceWarning))
def test_model_logistic_regression_multi_class_lbfgs(self):
penalty = "l2" if _sklearn_version() < pv.Version("0.21.0") else "none"
penalty = "l2"
model, X = fit_classification_model(
linear_model.LogisticRegression(
solver="lbfgs", penalty=penalty, max_iter=10000
Expand Down
2 changes: 1 addition & 1 deletion tests_onnxmltools/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_xgb_classifier_reglog(self):
conv_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
res = sess.run(None, {"input": X.astype(np.float32)})
assert_almost_equal(xgb.predict_proba(X), res[1])
assert_almost_equal(xgb.predict_proba(X), res[1], decimal=4)
assert_almost_equal(xgb.predict(X), res[0])

@unittest.skipIf(StackingClassifier is None, reason="new in 0.22")
Expand Down

0 comments on commit 4c8ab22

Please sign in to comment.