diff --git a/skl2onnx/common/_topology.py b/skl2onnx/common/_topology.py index 23c3c1274..1197cb8b4 100644 --- a/skl2onnx/common/_topology.py +++ b/skl2onnx/common/_topology.py @@ -87,11 +87,11 @@ def get_default_opset_for_domain(domain): if domain == "": return main_opset if domain == "ai.onnx.ml": - if main_opset >= 16: + if main_opset >= 18: return 3 - if main_opset < 6: - return 1 - return 2 + if main_opset >= 6: + return 2 + return 1 if domain == "ai.onnx.training": return 1 return None diff --git a/skl2onnx/operator_converters/tfidf_transformer.py b/skl2onnx/operator_converters/tfidf_transformer.py index fe6c20239..a57288223 100644 --- a/skl2onnx/operator_converters/tfidf_transformer.py +++ b/skl2onnx/operator_converters/tfidf_transformer.py @@ -36,24 +36,15 @@ def convert_sklearn_tfidf_transformer( # code scikit-learn # np.log(X.data, X.data) --> does not apply on null coefficient # X.data += 1 - # ONNX does not support sparse tensors before opset < 11 - # approximated by X.data += 1 --> np.log(X.data, X.data) - if operator.target_opset < 11: - plus1 = scope.get_unique_variable_name("plus1") - C = operator.inputs[0].type.shape[1] - ones = scope.get_unique_variable_name("ones") - cst = np.ones((C,), dtype=float_type) - container.add_initializer(ones, proto_dtype, [C], cst.flatten()) - apply_add(scope, data + [ones], plus1, container, broadcast=1) - plus1logged = scope.get_unique_variable_name("plus1logged") - apply_log(scope, plus1, plus1logged, container) - data = [plus1logged] - else: - # sparse containers have not yet been implemented. - raise RuntimeError( - "ONNX does not support sparse tensors before opset < 11, " - "sublinear_tf must be False." - ) + plus1 = scope.get_unique_variable_name("plus1") + C = operator.inputs[0].type.shape[1] + ones = scope.get_unique_variable_name("ones") + cst = np.ones((C,), dtype=float_type) + container.add_initializer(ones, proto_dtype, [C], cst.flatten()) + apply_add(scope, data + [ones], plus1, container, broadcast=1) + plus1logged = scope.get_unique_variable_name("plus1logged") + apply_log(scope, plus1, plus1logged, container) + data = [plus1logged] if op.use_idf: cst = op.idf_.astype(float_type) diff --git a/tests/test_issues_2024.py b/tests/test_issues_2024.py index 5b86475b7..258304e5d 100644 --- a/tests/test_issues_2024.py +++ b/tests/test_issues_2024.py @@ -4,11 +4,11 @@ class TestInvestigate(unittest.TestCase): def test_issue_1053(self): - import onnxruntime as rt from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier + import onnxruntime as rt from skl2onnx.common.data_types import FloatTensorType from skl2onnx import convert_sklearn @@ -39,6 +39,72 @@ def test_issue_1053(self): ] # Select a single sample. self.assertEqual(len(pred_onx.tolist()), 1) + def test_issue_1055(self): + import numpy as np + from numpy.testing import assert_almost_equal + import sklearn.feature_extraction.text + import sklearn.linear_model + import sklearn.pipeline + import onnxruntime as rt + import skl2onnx.common.data_types + + lr = sklearn.linear_model.LogisticRegression( + C=100, + multi_class="multinomial", + solver="sag", + class_weight="balanced", + n_jobs=-1, + ) + tf = sklearn.feature_extraction.text.TfidfVectorizer( + token_pattern="\\w+|[^\\w\\s]+", + ngram_range=(1, 1), + max_df=1.0, + min_df=1, + sublinear_tf=True, + ) + + pipe = sklearn.pipeline.Pipeline([("transformer", tf), ("logreg", lr)]) + + corpus = [ + "This is the first document.", + "This document is the second document.", + "And this is the third one.", + "Is this the first document?", + "more text", + "$words", + "I keep writing things", + "how many documents now?", + "this is a really long setence", + "is this a final document?", + ] + labels = ["1", "2", "1", "2", "1", "2", "1", "2", "1", "2"] + + pipe.fit(corpus, labels) + + onx = skl2onnx.convert_sklearn( + pipe, + "a model", + initial_types=[ + ("input", skl2onnx.common.data_types.StringTensorType([None, 1])) + ], + target_opset=19, + options={"zipmap": False}, + ) + for d in onx.opset_import: + if d.domain == "": + self.assertEqual(d.version, 19) + elif d.domain == "com.microsoft": + self.assertEqual(d.version, 1) + elif d.domain == "ai.onnx.ml": + self.assertEqual(d.version, 1) + + expected = pipe.predict_proba(corpus) + sess = rt.InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + got = sess.run(None, {"input": np.array(corpus).reshape((-1, 1))}) + assert_almost_equal(expected, got[1], decimal=2) + if __name__ == "__main__": unittest.main(verbosity=2)