Skip to content

Commit

Permalink
fix tfidf
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jan 23, 2024
1 parent 5ace336 commit 6ef3261
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
8 changes: 4 additions & 4 deletions skl2onnx/common/_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ def get_default_opset_for_domain(domain):
if domain == "":
return main_opset
if domain == "ai.onnx.ml":
if main_opset >= 16:
if main_opset >= 18:
return 3
if main_opset < 6:
return 1
return 2
if main_opset >= 6:
return 2
return 1
if domain == "ai.onnx.training":
return 1
return None
Expand Down
27 changes: 9 additions & 18 deletions skl2onnx/operator_converters/tfidf_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,15 @@ def convert_sklearn_tfidf_transformer(
# code scikit-learn
# np.log(X.data, X.data) --> does not apply on null coefficient
# X.data += 1
# ONNX does not support sparse tensors before opset < 11
# approximated by X.data += 1 --> np.log(X.data, X.data)
if operator.target_opset < 11:
plus1 = scope.get_unique_variable_name("plus1")
C = operator.inputs[0].type.shape[1]
ones = scope.get_unique_variable_name("ones")
cst = np.ones((C,), dtype=float_type)
container.add_initializer(ones, proto_dtype, [C], cst.flatten())
apply_add(scope, data + [ones], plus1, container, broadcast=1)
plus1logged = scope.get_unique_variable_name("plus1logged")
apply_log(scope, plus1, plus1logged, container)
data = [plus1logged]
else:
# sparse containers have not yet been implemented.
raise RuntimeError(
"ONNX does not support sparse tensors before opset < 11, "
"sublinear_tf must be False."
)
plus1 = scope.get_unique_variable_name("plus1")
C = operator.inputs[0].type.shape[1]
ones = scope.get_unique_variable_name("ones")
cst = np.ones((C,), dtype=float_type)
container.add_initializer(ones, proto_dtype, [C], cst.flatten())
apply_add(scope, data + [ones], plus1, container, broadcast=1)
plus1logged = scope.get_unique_variable_name("plus1logged")
apply_log(scope, plus1, plus1logged, container)
data = [plus1logged]

if op.use_idf:
cst = op.idf_.astype(float_type)
Expand Down
68 changes: 67 additions & 1 deletion tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

class TestInvestigate(unittest.TestCase):
def test_issue_1053(self):
import onnxruntime as rt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import onnxruntime as rt
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn

Expand Down Expand Up @@ -39,6 +39,72 @@ def test_issue_1053(self):
] # Select a single sample.
self.assertEqual(len(pred_onx.tolist()), 1)

def test_issue_1055(self):
    """Convert a sublinear TF-IDF + logistic regression pipeline to ONNX.

    The test fits a ``TfidfVectorizer(sublinear_tf=True)`` followed by a
    ``LogisticRegression`` on a small string corpus, converts the pipeline
    with ``target_opset=19`` and checks two things:

    * the opset versions recorded in the converted model for the domains
      ``""``, ``"com.microsoft"`` and ``"ai.onnx.ml"``;
    * that the second output returned by onnxruntime matches
      ``pipe.predict_proba`` to two decimals.
    """
    # ``from ... import`` is used throughout, as in test_issue_1053: code
    # scanning reported that 'skl2onnx.common.data_types' was imported with
    # both 'import' and 'import from' within this module.
    import numpy as np
    from numpy.testing import assert_almost_equal
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    import onnxruntime as rt
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import StringTensorType

    lr = LogisticRegression(
        C=100,
        multi_class="multinomial",
        solver="sag",
        class_weight="balanced",
        n_jobs=-1,
    )
    tf = TfidfVectorizer(
        token_pattern="\\w+|[^\\w\\s]+",
        ngram_range=(1, 1),
        max_df=1.0,
        min_df=1,
        sublinear_tf=True,
    )

    pipe = Pipeline([("transformer", tf), ("logreg", lr)])

    corpus = [
        "This is the first document.",
        "This document is the second document.",
        "And this is the third one.",
        "Is this the first document?",
        "more text",
        "$words",
        "I keep writing things",
        "how many documents now?",
        "this is a really long setence",
        "is this a final document?",
    ]
    labels = ["1", "2", "1", "2", "1", "2", "1", "2", "1", "2"]

    pipe.fit(corpus, labels)

    onx = convert_sklearn(
        pipe,
        "a model",
        initial_types=[("input", StringTensorType([None, 1]))],
        target_opset=19,
        options={"zipmap": False},
    )

    # Only these three domains are checked; any other domain present in the
    # model is ignored, exactly as an if/elif chain without ``else`` would.
    expected_versions = {"": 19, "com.microsoft": 1, "ai.onnx.ml": 1}
    for d in onx.opset_import:
        if d.domain in expected_versions:
            self.assertEqual(d.version, expected_versions[d.domain])

    expected = pipe.predict_proba(corpus)
    sess = rt.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # The model input is declared as [None, 1], hence one string per row.
    got = sess.run(None, {"input": np.array(corpus).reshape((-1, 1))})
    assert_almost_equal(expected, got[1], decimal=2)


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 6ef3261

Please sign in to comment.