From 84046e4ca6218348268b06fa18923c3601aa9e32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= <xadupre@users.noreply.github.com>
Date: Thu, 19 Oct 2023 14:33:18 +0200
Subject: [PATCH] Fixes FeatureHasher when dealing with a list of strings as
 input (#1025)

* Fixes FeatureHasher

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* update test and docs

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* update featurehasher

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix import

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* encode utf-8

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* add options

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix target opset

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* finalize the converter

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix feature hasher

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix encoding issue

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* black

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix unit test

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* disable test on old onnx

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* update example

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
---
 docs/tutorial/plot_transformer_discrepancy.py |   2 +-
 docs/tutorial/plot_weird_pandas_and_hash.py   | 347 ++++++++++++++++++
 requirements-dev.txt                          |   1 +
 skl2onnx/common/_container.py                 |   2 +-
 .../operator_converters/feature_hasher.py     | 109 +++++-
 .../operator_converters/text_vectoriser.py    |   2 +-
 skl2onnx/shape_calculators/feature_hasher.py  |   2 +-
 skl2onnx/sklapi/sklearn_text.py               |   4 +-
 tests/test_sklearn_feature_hasher.py          | 336 ++++++++++++++++-
 .../test_utils/reference_implementation_ml.py |   2 +-
 .../reference_implementation_text.py          |   2 +-
 .../reference_implementation_tree.py          |   2 +-
 tests/test_utils/utils_backend_onnx.py        |   4 +-
 13 files changed, 791 insertions(+), 24 deletions(-)
 create mode 100644 docs/tutorial/plot_weird_pandas_and_hash.py

diff --git a/docs/tutorial/plot_transformer_discrepancy.py b/docs/tutorial/plot_transformer_discrepancy.py
index 348ab070e..11098f5f5 100644
--- a/docs/tutorial/plot_transformer_discrepancy.py
+++ b/docs/tutorial/plot_transformer_discrepancy.py
@@ -50,7 +50,7 @@ def print_sparse_matrix(m):
 def diff(a, b):
     if a.shape != b.shape:
         raise ValueError(
-            f"Cannot compare matrices with different shapes " f"{a.shape} != {b.shape}."
+            f"Cannot compare matrices with different shapes {a.shape} != {b.shape}."
         )
     d = numpy.abs(a - b).sum() / a.size
     return d
diff --git a/docs/tutorial/plot_weird_pandas_and_hash.py b/docs/tutorial/plot_weird_pandas_and_hash.py
new file mode 100644
index 000000000..5c7526dc3
--- /dev/null
+++ b/docs/tutorial/plot_weird_pandas_and_hash.py
@@ -0,0 +1,347 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+FeatureHasher, pandas values and unexpected discrepancies
+=========================================================
+
+A game of finding where it goes wrong, and there are multiple places where it does.
+
+
+Initial example
++++++++++++++++
+"""
+import logging
+import numpy as np
+from pandas import DataFrame
+from onnxruntime import InferenceSession, SessionOptions
+from onnxruntime_extensions import get_library_path
+from sklearn.feature_extraction import FeatureHasher
+from sklearn.compose import ColumnTransformer
+from sklearn.pipeline import Pipeline
+from sklearn.ensemble import GradientBoostingClassifier
+from skl2onnx import to_onnx
+from skl2onnx.common.data_types import StringTensorType
+
+log = logging.getLogger("skl2onnx")
+log.setLevel(logging.ERROR)
+
+
+df = DataFrame(
+    {
+        "Cat1": ["a", "b", "d", "abd", "e", "z", "ez"],
+        "Cat2": ["A", "B", "D", "ABD", "e", "z", "ez"],
+        "Label": [1, 1, 0, 0, 1, 0, 0],
+    }
+)
+
+cat_features = [c for c in df.columns if "Cat" in c]
+X_train = df[cat_features]
+
+X_train["cat_features"] = df[cat_features].values.tolist()
+X_train = X_train.drop(cat_features, axis=1)
+y_train = df["Label"]
+
+pipe = Pipeline(
+    steps=[
+        (
+            "preprocessor",
+            ColumnTransformer(
+                [
+                    (
+                        "cat_preprocessor",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        "cat_features",
+                    )
+                ],
+                sparse_threshold=0.0,
+            ),
+        ),
+        ("classifier", GradientBoostingClassifier(n_estimators=2, max_depth=2)),
+    ],
+)
+pipe.fit(X_train, y_train)
+
+
+###################################
+# Conversion to ONNX.
+
+onx = to_onnx(
+    pipe,
+    initial_types=[("cat_features", StringTensorType([None, None]))],
+    options={"zipmap": False},
+)
+
+###################################
+# Do the probabilities match? They show many discrepancies.
+
+expected_proba = pipe.predict_proba(X_train)
+sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
+
+
+got = sess.run(None, dict(cat_features=X_train.values))
+
+
+print("expected probabilities")
+print(expected_proba)
+
+print("onnx probabilities")
+print(got[1])
+
+#########################################
+# Let's check the feature hasher
+# ++++++++++++++++++++++++++++++
+#
+# We just remove the classifier.
+
+pipe_hash = Pipeline(
+    steps=[
+        (
+            "preprocessor",
+            ColumnTransformer(
+                [
+                    (
+                        "cat_preprocessor",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        "cat_features",
+                    )
+                ],
+                sparse_threshold=0.0,
+            ),
+        ),
+    ],
+)
+pipe_hash.fit(X_train, y_train)
+
+onx = to_onnx(
+    pipe_hash,
+    initial_types=[("cat_features", StringTensorType([None, None]))],
+    options={"zipmap": False},
+)
+
+expected = pipe_hash.transform(X_train)
+sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
+
+
+got = sess.run(None, dict(cat_features=X_train.values))
+
+
+print("expected hashed features")
+print(expected)
+
+print("onnx hashed features")
+print(got[0])
+
+#######################################
+# Nothing seems to be working.
+#
+# First proposal
+# ++++++++++++++
+#
+# The instruction
+# ``X_train["cat_features"] = df[cat_features].values.tolist()``
+# creates a DataFrame with one column whose elements are lists of two values.
+# scikit-learn accepts lists and can process a variable number of elements
+# per list. onnxruntime cannot do that, so the pipeline must be changed
+# into the following.
+
+pipe_hash = Pipeline(
+    steps=[
+        (
+            "preprocessor",
+            ColumnTransformer(
+                [
+                    (
+                        "cat_preprocessor1",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        [0],
+                    ),
+                    (
+                        "cat_preprocessor2",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        [1],
+                    ),
+                ],
+                sparse_threshold=0.0,
+            ),
+        ),
+    ],
+)
+
+X_train_skl = df[cat_features].copy()
+for c in cat_features:
+    X_train_skl[c] = X_train_skl[c].values.tolist()
+
+pipe_hash.fit(X_train_skl.values, y_train)
+
+onx = to_onnx(
+    pipe_hash,
+    initial_types=[
+        ("cat1", StringTensorType([None, 1])),
+        ("cat2", StringTensorType([None, 1])),
+    ],
+    options={"zipmap": False},
+)
+
+
+expected = pipe_hash.transform(X_train_skl.values)
+sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
+
+
+got = sess.run(
+    None,
+    dict(
+        cat1=df["Cat1"].values.reshape((-1, 1)), cat2=df["Cat2"].values.reshape((-1, 1))
+    ),
+)
+
+
+print("expected fixed hashed features")
+print(expected)
+
+print("onnx fixed hashed features")
+print(got[0])
+
+###########################################
+# This is not the original pipeline. It produces 16 columns instead of 8,
+# but the ONNX output now matches scikit-learn.
+# One option to recover the original 8 columns would be to add the first
+# 8 columns to the other 8, as shown below, or to use a custom converter.
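+#
+# With ``alternate_sign=False`` the hashed features are additive: summing
+# both halves of the previous output reproduces the 8 columns the original
+# single FeatureHasher computes. A quick sketch on the arrays obtained above:
+
+merged = got[0][:, :8] + got[0][:, 8:]
+print("hashed features merged back into 8 columns")
+print(merged)
+
+###########################################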
+#
+# Second proposal
+# +++++++++++++++
+#
+# We use the same initial pipeline but we tweak the input
+# onnxruntime receives.
+
+pipe_hash = Pipeline(
+    steps=[
+        (
+            "preprocessor",
+            ColumnTransformer(
+                [
+                    (
+                        "cat_preprocessor",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        "cat_features",
+                    )
+                ],
+                sparse_threshold=0.0,
+            ),
+        ),
+    ],
+)
+pipe_hash.fit(X_train, y_train)
+
+onx = to_onnx(
+    pipe_hash,
+    initial_types=[("cat_features", StringTensorType([None, 1]))],
+    options={"zipmap": False, "preprocessor__cat_preprocessor__separator": "#"},
+)
+
+expected = pipe_hash.transform(X_train)
+
+
+so = SessionOptions()
+so.register_custom_ops_library(get_library_path())
+sess = InferenceSession(onx.SerializeToString(), so, providers=["CPUExecutionProvider"])
+
+# We merge both columns Cat1 and Cat2 into a single column cat_features.
+df_fixed = DataFrame()
+df_fixed["cat_features"] = np.array([f"{a}#{b}" for a, b in X_train["cat_features"]])
+
+got = sess.run(None, {"cat_features": df_fixed[["cat_features"]].values})
+
+print("expected original hashed features")
+print(expected)
+
+print("onnx fixed original hashed features")
+print(got[0])
+
+############################################
+# It works now.
+#
+# Sparsity?
+# +++++++++
+#
+# Let's now try with the classifier and without `sparse_threshold=0.0`.
+
+pipe = Pipeline(
+    steps=[
+        (
+            "preprocessor",
+            ColumnTransformer(
+                [
+                    (
+                        "cat_preprocessor",
+                        FeatureHasher(
+                            n_features=8,
+                            input_type="string",
+                            alternate_sign=False,
+                            dtype=np.float32,
+                        ),
+                        "cat_features",
+                    )
+                ],
+                # sparse_threshold=0.0,
+            ),
+        ),
+        ("classifier", GradientBoostingClassifier(n_estimators=2, max_depth=2)),
+    ],
+)
+pipe.fit(X_train, y_train)
+expected = pipe.predict_proba(X_train)
+
+
+onx = to_onnx(
+    pipe,
+    initial_types=[("cat_features", StringTensorType([None, 1]))],
+    options={"zipmap": False, "preprocessor__cat_preprocessor__separator": "#"},
+)
+
+so = SessionOptions()
+so.register_custom_ops_library(get_library_path())
+sess = InferenceSession(onx.SerializeToString(), so, providers=["CPUExecutionProvider"])
+got = sess.run(None, {"cat_features": df_fixed[["cat_features"]].values})
+
+
+print("expected probabilies")
+print(expected)
+
+print("onnx probabilies")
+print(got[1])
+
+###########################################
+# scikit-learn keeps the sparse output from
+# the FeatureHasher. onnxruntime does not support
+# sparse features, which may have an impact on the conversion
+# if the model following this step distinguishes between a
+# missing sparse value and zero.
+# That does not seem to be the case for this model, but
+# other models or libraries may behave differently.
+
+print(pipe.steps[0][-1].transform(X_train))
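+
+###########################################
+# For a dense view of the same output, comparable with the ONNX result,
+# the sparse matrix can be densified (a small sketch):
+
+print(np.asarray(pipe.steps[0][-1].transform(X_train).todense()))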
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ea181f954..d0ebc86f0 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,6 @@
 # tests
 black
+onnxruntime-extensions
 pandas
 py-cpuinfo
 pybind11
diff --git a/skl2onnx/common/_container.py b/skl2onnx/common/_container.py
index d242ec8b1..ce42e43b6 100644
--- a/skl2onnx/common/_container.py
+++ b/skl2onnx/common/_container.py
@@ -629,7 +629,7 @@ def add_node(
             attrs["axes"] is None or not isinstance(attrs["axes"], (list, np.ndarray))
         ):
             raise TypeError(
-                f"axes must be a list or an array not " f"{type(attrs['axes'])}."
+                f"axes must be a list or an array not {type(attrs['axes'])}."
             )
         if name is None or not isinstance(name, str) or name == "":
             name = f"N{len(self.nodes)}"
diff --git a/skl2onnx/operator_converters/feature_hasher.py b/skl2onnx/operator_converters/feature_hasher.py
index 4183422f3..2c652cfb1 100644
--- a/skl2onnx/operator_converters/feature_hasher.py
+++ b/skl2onnx/operator_converters/feature_hasher.py
@@ -3,6 +3,7 @@
 import numpy as np
 from onnx import TensorProto
 from onnx.helper import make_tensor
+from onnx.numpy_helper import from_array
 from ..common._registration import register_converter
 from ..common._topology import Scope, Operator
 from ..common._container import ModelComponentContainer
@@ -11,7 +12,6 @@
 def convert_sklearn_feature_hasher(
     scope: Scope, operator: Operator, container: ModelComponentContainer
 ):
-    X = operator.inputs[0]
     out = operator.outputs
     op = operator.raw_operator
     if op.input_type != "string":
@@ -20,10 +20,90 @@ def convert_sklearn_feature_hasher(
             f"input_type='string' not {op.input_type!r}."
         )
 
+    # If the option separator is not None, the converter assumes the input
+    # is a single string column where each element is a list of strings
+    # joined with this separator into one string.
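+    # For example, with separator="#", the element "a#A" is split back into
+    # the strings "a" and "A", which are then hashed individually.
+    # The option is enabled at conversion time, e.g. with
+    # options={FeatureHasher: {"separator": "#"}}.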
+    options = container.get_options(op, dict(separator=None))
+    separator = options.get("separator", None)
+
+    if separator is not None:
+        # Let's split the columns
+        delimiter = scope.get_unique_variable_name("delimiter")
+        container.add_initializer(
+            delimiter, TensorProto.STRING, [], [separator.encode("utf-8")]
+        )
+        empty_string = scope.get_unique_variable_name("empty_string")
+        container.add_initializer(
+            empty_string, TensorProto.STRING, [], ["".encode("utf-8")]
+        )
+        skip_empty = scope.get_unique_variable_name("skip_empty")
+        container.add_initializer(skip_empty, TensorProto.BOOL, [], [False])
+        flat_shape = scope.get_unique_variable_name("flat_shape")
+        container.add_initializer(flat_shape, TensorProto.INT64, [1], [-1])
+        zero = scope.get_unique_variable_name("zero")
+        container.add_initializer(zero, TensorProto.INT64, [1], [0])
+        one = scope.get_unique_variable_name("one")
+        container.add_initializer(one, TensorProto.INT64, [1], [1])
+
+        to_concat = []
+        for i, col_to_split in enumerate(operator.inputs):
+            reshaped = scope.get_unique_variable_name(f"reshaped{i}")
+            container.add_node(
+                "Reshape", [col_to_split.full_name, flat_shape], [reshaped]
+            )
+            out_indices = scope.get_unique_variable_name(f"out_indices{i}")
+            out_text = scope.get_unique_variable_name(f"out_text{i}")
+            out_shape = scope.get_unique_variable_name(f"out_shape{i}")
+            container.add_node(
+                "StringSplit",
+                [reshaped, delimiter, skip_empty],
+                [out_indices, out_text, out_shape],
+                op_domain="ai.onnx.contrib",
+                op_version=1,
+            )
+            shape = scope.get_unique_variable_name(f"shape{i}")
+            container.add_node("Shape", [col_to_split.full_name], [shape])
+
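+            # Build a string tensor of shape out_shape filled with empty
+            # strings: ConstantOfShape cannot produce strings, so create
+            # int64 zeros, cast them to bool, and let Where pick the empty
+            # string for every position; ScatterND then writes the split
+            # tokens at their indices.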
+            emptyi = scope.get_unique_variable_name(f"emptyi{i}")
+            container.add_node(
+                "ConstantOfShape",
+                [out_shape],
+                [emptyi],
+                value=from_array(np.array([0], dtype=np.int64)),
+            )
+            emptyb = scope.get_unique_variable_name(f"emptyb{i}")
+            container.add_node("Cast", [emptyi], [emptyb], to=TensorProto.BOOL)
+            emptys = scope.get_unique_variable_name(f"emptys{i}")
+            container.add_node("Where", [emptyb, empty_string, empty_string], [emptys])
+            flat_split = scope.get_unique_variable_name(f"flat_split{i}")
+            container.add_node(
+                "ScatterND", [emptys, out_indices, out_text], [flat_split]
+            )
+            # shape_1 = scope.get_unique_variable_name(f"shape_1{i}")
+            # container.add_node("Concat", [shape, flat_shape], [shape_1], axis=0)
+
+            split = scope.get_unique_variable_name(f"split{i}")
+            to_concat.append(split)
+            # container.add_node("Reshape", [flat_split, shape_1], [split])
+            container.add_node("Identity", [flat_split], [split])
+        if len(to_concat) == 1:
+            input_hasher = to_concat[0]
+        else:
+            input_hasher = scope.get_unique_variable_name("concatenated")
+            container.add_node("Concat", to_concat, [input_hasher], axis=1)
+    elif len(operator.inputs) == 1:
+        X = operator.inputs[0]
+        input_hasher = X.full_name
+    else:
+        raise RuntimeError(
+            f"Only one input is expected but received "
+            f"{[i.name for i in operator.inputs]}."
+        )
+
     hashed_ = scope.get_unique_variable_name("hashed_")
     container.add_node(
         "MurmurHash3",
-        X.full_name,
+        input_hasher,
         hashed_,
         positive=0,
         seed=0,
@@ -66,8 +146,6 @@ def convert_sklearn_feature_hasher(
 
     new_shape = scope.get_unique_variable_name("new_shape")
     container.add_initializer(new_shape, TensorProto.INT64, [2], [-1, 1])
-    new_shape2 = scope.get_unique_variable_name("new_shape2")
-    container.add_initializer(new_shape2, TensorProto.INT64, [2], [1, -1])
 
     # values
     if op.alternate_sign:
@@ -112,21 +190,36 @@ def convert_sklearn_feature_hasher(
         "ScatterElements", [zerot, indices_reshaped, values_reshaped], final, axis=1
     )
 
-    # at this point, every string has been processed as if it was in
-    # in a single columns.
+    # at this point, every string has been processed as if it were
+    # in a single column.
     # in case there is more than one column, we need to reduce over
     # the last dimension
     input_shape = scope.get_unique_variable_name("input_shape")
-    container.add_node("Shape", X.full_name, input_shape)
+    container.add_node("Shape", input_hasher, input_shape)
     shape_not_last = scope.get_unique_variable_name("shape_not_last")
     container.add_node("Slice", [input_shape, zero, mone], shape_not_last)
     final_shape = scope.get_unique_variable_name("final_last")
     container.add_node("Concat", [shape_not_last, mone, nf], final_shape, axis=0)
     final_reshaped = scope.get_unique_variable_name("final_reshaped")
     container.add_node("Reshape", [final, final_shape], final_reshaped)
+
+    if op.dtype == np.float32:
+        to = TensorProto.FLOAT
+    elif op.dtype == np.float64:
+        to = TensorProto.DOUBLE
+    elif op.dtype in (np.int32, np.uint32, np.int64):
+        to = TensorProto.INT64
+    else:
+        raise RuntimeError(
+            f"Converter is not implemented for FeatureHasher.dtype={op.dtype}."
+        )
+    final_reshaped_cast = scope.get_unique_variable_name("final_reshaped_cast")
+    container.add_node("Cast", [final_reshaped], final_reshaped_cast, to=to)
     container.add_node(
-        "ReduceSum", [final_reshaped, mtwo], out[0].full_name, keepdims=0
+        "ReduceSum", [final_reshaped_cast, mtwo], out[0].full_name, keepdims=0
     )
 
 
-register_converter("SklearnFeatureHasher", convert_sklearn_feature_hasher)
+register_converter(
+    "SklearnFeatureHasher", convert_sklearn_feature_hasher, options={"separator": None}
+)
diff --git a/skl2onnx/operator_converters/text_vectoriser.py b/skl2onnx/operator_converters/text_vectoriser.py
index b0244ba65..6c26f11fe 100644
--- a/skl2onnx/operator_converters/text_vectoriser.py
+++ b/skl2onnx/operator_converters/text_vectoriser.py
@@ -288,7 +288,7 @@ def convert_sklearn_text_vectorizer(
     for w in stop_words:
         if not isinstance(w, str):
             raise TypeError(
-                f"One stop word is not a string {w!r} " f"in stop_words={stop_words}."
+                f"One stop word is not a string {w!r} in stop_words={stop_words}."
             )
 
     if op.lowercase or stop_words:
diff --git a/skl2onnx/shape_calculators/feature_hasher.py b/skl2onnx/shape_calculators/feature_hasher.py
index 9ba186e48..7835274cd 100644
--- a/skl2onnx/shape_calculators/feature_hasher.py
+++ b/skl2onnx/shape_calculators/feature_hasher.py
@@ -29,7 +29,7 @@ def calculate_sklearn_feature_hasher(operator):
         operator.outputs[0].type = Int64TensorType(shape=shape)
     else:
         raise RuntimeError(
-            f"Converter is not implemented for " f"FeatureHasher.dtype={model.dtype}."
+            f"Converter is not implemented for FeatureHasher.dtype={model.dtype}."
         )
 
 
diff --git a/skl2onnx/sklapi/sklearn_text.py b/skl2onnx/sklapi/sklearn_text.py
index 709398dfc..43b474b05 100644
--- a/skl2onnx/sklapi/sklearn_text.py
+++ b/skl2onnx/sklapi/sklearn_text.py
@@ -162,7 +162,7 @@ def fit(self, X, y=None):
         self.same_ = same
         if self.stop_words != same.stop_words:
             raise AssertionError(
-                f"Different stop_words {self.stop_words} " f"!= {same.stop_words}."
+                f"Different stop_words {self.stop_words} != {same.stop_words}."
             )
         update, dups = self._fix_vocabulary(same.vocabulary_, self.vocabulary_)
         self.updated_vocabulary_ = update
@@ -228,7 +228,7 @@ def fit(self, X, y=None):
         self.same_ = same
         if self.stop_words != same.stop_words:
             raise AssertionError(
-                f"Different stop_words {self.stop_words} " f"!= {same.stop_words}."
+                f"Different stop_words {self.stop_words} != {same.stop_words}."
             )
         update, dups = self._fix_vocabulary(same.vocabulary_, self.vocabulary_)
         self.updated_vocabulary_ = update
diff --git a/tests/test_sklearn_feature_hasher.py b/tests/test_sklearn_feature_hasher.py
index 50dfc8d4c..aab1fb037 100644
--- a/tests/test_sklearn_feature_hasher.py
+++ b/tests/test_sklearn_feature_hasher.py
@@ -6,8 +6,9 @@
 import unittest
 import packaging.version as pv
 import numpy as np
+from sklearn.utils._testing import assert_almost_equal
 from pandas import DataFrame
-from onnx import TensorProto
+from onnx import TensorProto, __version__ as onnx_version
 from onnx.helper import (
     make_model,
     make_node,
@@ -16,8 +17,17 @@
     make_opsetid,
 )
 from onnx.checker import check_model
-from onnxruntime import __version__ as ort_version
+
+try:
+    from onnx.reference import ReferenceEvaluator
+    from onnx.reference.op_run import OpRun
+except ImportError:
+    ReferenceEvaluator = None
+from onnxruntime import __version__ as ort_version, SessionOptions
 from sklearn.feature_extraction import FeatureHasher
+from sklearn.compose import ColumnTransformer
+from sklearn.pipeline import Pipeline
+from sklearn.tree import DecisionTreeClassifier
 from skl2onnx import to_onnx
 from skl2onnx.common.data_types import (
     StringTensorType,
@@ -25,7 +35,11 @@
     FloatTensorType,
     DoubleTensorType,
 )
-from test_utils import TARGET_OPSET, TARGET_IR, InferenceSessionEx as InferenceSession
+from test_utils import (
+    TARGET_OPSET,
+    TARGET_IR,
+    InferenceSessionEx as InferenceSession,
+)
 
 
 class TestSklearnFeatureHasher(unittest.TestCase):
@@ -203,7 +217,6 @@ def test_feature_hasher_dataframe(self):
         )
         model.fit(data)
         expected = model.transform(data).todense()
-        print(expected)
 
         model_onnx = to_onnx(
             model,
@@ -251,6 +264,319 @@ def test_feature_hasher_two_columns_unicode(self):
             if a != b:
                 raise AssertionError(f"Discrepancies at line {i}: {a} != {b}")
 
+    def test_feature_hasher_pipeline(self):
+        data = {
+            "Education": ["a", "b", "d", "abd"],
+            "Label": [1, 1, 0, 0],
+        }
+        df = DataFrame(data)
+
+        cat_features = ["Education"]
+        X_train = df[cat_features]
+
+        X_train["cat_features"] = df[cat_features].values.tolist()
+        X_train = X_train.drop(cat_features, axis=1)
+        y_train = df["Label"]
+
+        preprocessing_pipeline = ColumnTransformer(
+            [
+                (
+                    "cat_preprocessor",
+                    FeatureHasher(
+                        n_features=16,
+                        input_type="string",
+                        alternate_sign=False,
+                        dtype=np.float32,
+                    ),
+                    "cat_features",
+                )
+            ],
+            sparse_threshold=0.0,
+        )
+
+        complete_pipeline = Pipeline(
+            steps=[
+                ("preprocessor", preprocessing_pipeline),
+                ("classifier", DecisionTreeClassifier(max_depth=2)),
+            ],
+        )
+        complete_pipeline.fit(X_train, y_train)
+
+        # first check
+        model = FeatureHasher(
+            n_features=16,
+            input_type="string",
+            alternate_sign=False,
+            dtype=np.float32,
+        )
+        X_train_ort1 = X_train.values.reshape((-1, 1))
+        with self.assertRaises(TypeError):
+            np.asarray(model.transform(X_train_ort1).todense())
+        input_strings = ["a", "b", "d", "abd"]
+        X_train_ort2 = np.array(input_strings, dtype=object).reshape((-1, 1))
+        model.fit(X_train_ort2)
+        # type(X_train_ort2[0, 0]) == str while type(X_train_ort1[0, 0]) == list
+        expected2 = np.asarray(model.transform(X_train_ort2).todense())
+        model_onnx = to_onnx(
+            model,
+            initial_types=[("cat_features", StringTensorType([None, 1]))],
+            target_opset=TARGET_OPSET,
+        )
+        sess = InferenceSession(
+            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got2 = sess.run(None, dict(cat_features=X_train_ort2))
+        assert_almost_equal(expected2, got2[0])
+        got1 = sess.run(None, dict(cat_features=X_train_ort1))
+        with self.assertRaises(AssertionError):
+            assert_almost_equal(expected2, got1[0])
+
+        # check hash
+        X_train_ort = X_train.values
+        expected = np.asarray(
+            complete_pipeline.steps[0][-1]
+            .transformers_[0][1]
+            .transform(X_train.values.ravel())
+            .todense()
+        )
+        model_onnx = to_onnx(
+            complete_pipeline.steps[0][-1].transformers_[0][1],
+            initial_types=[("cat_features", StringTensorType([None, 1]))],
+            target_opset=TARGET_OPSET,
+        )
+        sess = InferenceSession(
+            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, dict(cat_features=X_train_ort))
+        with self.assertRaises(AssertionError):
+            assert_almost_equal(expected, got[0])
+        got = sess.run(None, dict(cat_features=X_train_ort2))
+        assert_almost_equal(expected, got[0])
+
+        # transform
+        X_train_ort = X_train.values
+        expected = complete_pipeline.steps[0][-1].transform(X_train)
+        model_onnx = to_onnx(
+            complete_pipeline.steps[0][-1],
+            initial_types=[("cat_features", StringTensorType([None, 1]))],
+            target_opset=TARGET_OPSET,
+        )
+        sess = InferenceSession(
+            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, dict(cat_features=X_train_ort))
+        with self.assertRaises(AssertionError):
+            assert_almost_equal(expected, got[0].astype(np.float64))
+        got = sess.run(None, dict(cat_features=X_train_ort2))
+        assert_almost_equal(expected, got[0].astype(np.float64))
+
+        # classifier
+        expected = complete_pipeline.predict_proba(X_train)
+        labels = complete_pipeline.predict(X_train)
+        model_onnx = to_onnx(
+            complete_pipeline,
+            initial_types=[("cat_features", StringTensorType([None, 1]))],
+            target_opset=TARGET_OPSET,
+            options={"zipmap": False},
+        )
+
+        sess = InferenceSession(
+            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        X_train_ort = X_train.values
+        got = sess.run(None, dict(cat_features=X_train_ort))
+        with self.assertRaises(AssertionError):
+            assert_almost_equal(expected, got[1].astype(np.float64))
+        got = sess.run(None, dict(cat_features=X_train_ort2))
+        assert_almost_equal(labels, got[0])
+
+    @unittest.skipIf(
+        pv.Version(onnx_version) < pv.Version("1.11"), reason="onnx is too old"
+    )
+    def test_feature_hasher_pipeline_list(self):
+        pipe_hash = Pipeline(
+            steps=[
+                (
+                    "preprocessor",
+                    ColumnTransformer(
+                        [
+                            (
+                                "cat_features",
+                                FeatureHasher(
+                                    n_features=8,
+                                    input_type="string",
+                                    alternate_sign=False,
+                                    dtype=np.float32,
+                                ),
+                                "cat_features",
+                            ),
+                        ],
+                        sparse_threshold=0.0,
+                    ),
+                ),
+            ],
+        )
+
+        df = DataFrame(
+            {
+                "Cat1": ["a", "b", "d", "abd", "e", "z", "ez"],
+                "Cat2": ["A", "B", "D", "ABD", "e", "z", "ez"],
+            }
+        )
+
+        cat_features = [c for c in df.columns if "Cat" in c]
+        X_train = df[cat_features].copy()
+        X_train["cat_features"] = df[cat_features].values.tolist()
+        X_train = X_train.drop(cat_features, axis=1)
+        pipe_hash.fit(X_train)
+        expected = pipe_hash.transform(X_train)
+
+        onx = to_onnx(
+            pipe_hash,
+            initial_types=[("cat_features", StringTensorType([None, 1]))],
+            options={FeatureHasher: {"separator": "#"}},
+            target_opset=TARGET_OPSET,
+        )
+
+        dfx = df.copy()
+        dfx["cat_features"] = df[cat_features].agg("#".join, axis=1)
+        feeds = dict(cat_features=dfx["cat_features"].values.reshape((-1, 1)))
+
+        if ReferenceEvaluator is not None:
+
+            class StringSplit(OpRun):
+                op_domain = "ai.onnx.contrib"
+
+                def _run(self, input, separator, skip_empty, **kwargs):
+                    # kwargs should be empty; possibly a bug in onnx?
+                    delimiter = (
+                        str(separator[0])
+                        if len(separator.shape) > 0
+                        else str(separator)
+                    )
+                    skip_empty = (
+                        bool(skip_empty[0])
+                        if len(skip_empty.shape)
+                        else bool(skip_empty)
+                    )
+                    texts = []
+                    indices = []
+                    max_split = 0
+                    for row, text in enumerate(input):
+                        if not text:
+                            continue
+                        res = text.split(delimiter)
+                        if skip_empty:
+                            res = [t for t in res if t]
+                        texts.extend(res)
+                        max_split = max(max_split, len(res))
+                        indices.extend((row, i) for i in range(len(res)))
+                    return (
+                        np.array(indices, dtype=np.int64),
+                        np.array(texts),
+                        np.array([len(input), max_split], dtype=np.int64),
+                    )
+
+            class MurmurHash3(OpRun):
+                op_domain = "com.microsoft"
+
+                @staticmethod
+                def rotl(num, bits):
+                    # rotate a 32-bit integer left by `bits` positions
+                    num = int(num) & 0xFFFFFFFF
+                    return ((num << bits) | (num >> (32 - bits))) & 0xFFFFFFFF
+
+                @staticmethod
+                def fmix(h: int):
+                    h ^= h >> 16
+                    h = np.uint32(
+                        (int(h) * int(0x85EBCA6B)) % (int(np.iinfo(np.uint32).max) + 1)
+                    )
+                    h ^= h >> 13
+                    h = np.uint32(
+                        (int(h) * int(0xC2B2AE35)) % (int(np.iinfo(np.uint32).max) + 1)
+                    )
+                    h ^= h >> 16
+                    return h
+
+                @staticmethod
+                def MurmurHash3_x86_32(data, seed):
+                    le = len(data)
+                    nblocks = le // 4
+                    h1 = seed
+
+                    c1 = 0xCC9E2D51
+                    c2 = 0x1B873593
+
+                    iblock = nblocks * 4
+
+                    for i in range(-nblocks, 0):
+                        k1 = np.uint32(data[iblock + i])
+                        k1 *= c1
+                        k1 = MurmurHash3.rotl(k1, 15)
+                        k1 *= c2
+                        h1 ^= k1
+                        h1 = MurmurHash3.rotl(h1, 13)
+                        h1 = h1 * 5 + 0xE6546B64
+
+                    k1 = 0
+
+                    if le & 3 >= 3:
+                        k1 ^= np.uint32(data[iblock + 2]) << 16
+                    if le & 3 >= 2:
+                        k1 ^= np.uint32(data[iblock + 1]) << 8
+                    if le & 3 >= 1:
+                        k1 ^= np.uint32(data[iblock])
+                        k1 *= c1
+                        k1 = MurmurHash3.rotl(k1, 15)
+                        k1 *= c2
+                        h1 ^= k1
+
+                    h1 ^= le
+
+                    h1 = MurmurHash3.fmix(h1)
+                    return h1
+
+                def _run(self, x, positive: int = None, seed: int = None):
+                    x2 = x.reshape((-1,))
+                    y = np.empty(x2.shape, dtype=np.uint32)
+                    for i in range(y.shape[0]):
+                        b = x2[i].encode("utf-8")
+                        h = MurmurHash3.MurmurHash3_x86_32(b, seed)
+                        y[i] = h
+                    return (y.reshape(x.shape),)
+
+            ref = ReferenceEvaluator(onx, new_ops=[StringSplit, MurmurHash3])
+            got_py = ref.run(None, feeds)
+        else:
+            got_py = None
+
+        from onnxruntime_extensions import get_library_path
+
+        so = SessionOptions()
+        so.register_custom_ops_library(get_library_path())
+        sess = InferenceSession(
+            onx.SerializeToString(), so, providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)
+        assert_almost_equal(expected, got[0])
+
+        if ReferenceEvaluator is not None:
+            # The pure Python implementation above only approximates
+            # murmurhash3 (signedness and integer overflow differ),
+            # so only the shapes are compared.
+            assert_almost_equal(expected.shape, got_py[0].shape)
+
 
 if __name__ == "__main__":
-    unittest.main()
+    import logging
+
+    logger = logging.getLogger("skl2onnx")
+    logger.setLevel(logging.ERROR)
+    logger = logging.getLogger("onnx-extended")
+    logger.setLevel(logging.ERROR)
+
+    unittest.main(verbosity=2)
diff --git a/tests/test_utils/reference_implementation_ml.py b/tests/test_utils/reference_implementation_ml.py
index d25daec33..7e1ff7223 100644
--- a/tests/test_utils/reference_implementation_ml.py
+++ b/tests/test_utils/reference_implementation_ml.py
@@ -235,7 +235,7 @@ def _run(self, x, cats_int64s=None, cats_strings=None, zeros=None):
                             res[a, i, j] = 1.0
             else:
                 raise RuntimeError(
-                    f"This operator is not implemented " f"for " f"shape {x.shape}."
+                    f"This operator is not implemented for shape {x.shape}."
                 )
 
             if not self.zeros:
diff --git a/tests/test_utils/reference_implementation_text.py b/tests/test_utils/reference_implementation_text.py
index 0ff6497e0..f5b4d84a8 100644
--- a/tests/test_utils/reference_implementation_text.py
+++ b/tests/test_utils/reference_implementation_text.py
@@ -99,7 +99,7 @@ def _run_tokenization(text, stops, split, mark, pad_value):
                 res = np.array(res)
             else:
                 raise RuntimeError(  # pragma: no cover
-                    f"Only vector or matrices are supported " f"not shape {text.shape}."
+                    f"Only vector or matrices are supported not shape {text.shape}."
                 )
             return (res,)
 
diff --git a/tests/test_utils/reference_implementation_tree.py b/tests/test_utils/reference_implementation_tree.py
index 6cf7abbf0..d25b3e283 100644
--- a/tests/test_utils/reference_implementation_tree.py
+++ b/tests/test_utils/reference_implementation_tree.py
@@ -345,7 +345,7 @@ def _run(
             elif classlabels_strings is not None:
                 if len(classlabels_strings) == 1:
                     raise NotImplementedError(
-                        f"classlabels_strings={classlabels_strings}, " f"not supported."
+                        f"classlabels_strings={classlabels_strings}, not supported."
                     )
                 labels = np.array([classlabels_strings[i] for i in labels])
 
diff --git a/tests/test_utils/utils_backend_onnx.py b/tests/test_utils/utils_backend_onnx.py
index ff9cb5763..fbab2109b 100644
--- a/tests/test_utils/utils_backend_onnx.py
+++ b/tests/test_utils/utils_backend_onnx.py
@@ -102,7 +102,7 @@ def _run(self, data, axis=None, keepdims=None, select_last_index=None):
                         return (_argmax(data, axis=axis, keepdims=keepdims),)
                     except Exception as e:
                         raise RuntimeError(
-                            f"Issue with shape={data.shape} " f"and axis={axis}."
+                            f"Issue with shape={data.shape} and axis={axis}."
                         ) from e
                 raise NotImplementedError("Unused in sklearn-onnx.")
 
@@ -407,7 +407,7 @@ def _log_arg(self, a):
                 elements = a.ravel().tolist()
                 if len(elements) > 5:
                     elements = elements[:5]
-                    return f"{a.dtype}:{a.shape}:" f"{','.join(map(str, elements))}..."
+                    return f"{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
                 return f"{a.dtype}:{a.shape}:{elements}"
             if hasattr(a, "append"):
                 return ", ".join(map(self._log_arg, a))