Skip to content

Commit

Permalink
update to retain two original column
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 8, 2023
1 parent 98f9888 commit 8d1e174
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
10 changes: 5 additions & 5 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
Py39-Onnx120-Rt1111-Skl11:
do.bench: '0'
python.version: '3.9'
numpy.version: '>=1.21.0'
numpy.version: '>=1.21.0,<1.25.0'
scipy.version: '>=1.7.0'
onnx.version: 'onnx==1.12.0'
onnx.target_opset: ''
Expand All @@ -131,7 +131,7 @@ jobs:
Py39-Onnx1110-Rt1111-Skl11:
do.bench: '0'
python.version: '3.9'
numpy.version: '>=1.21.0'
numpy.version: '>=1.21.0,<1.25.0'
scipy.version: '>=1.7.0'
onnx.version: 'onnx==1.11.0'
onnx.target_opset: ''
Expand All @@ -143,7 +143,7 @@ jobs:
Py39-Onnx1110-Rt1111-Skl10:
do.bench: '0'
python.version: '3.9'
numpy.version: '>=1.22.3'
numpy.version: '>=1.21.0,<1.25.0'
scipy.version: '>=1.7.0'
onnx.version: 'onnx==1.11.0'
onnx.target_opset: ''
Expand All @@ -155,7 +155,7 @@ jobs:
Py39-Onnx1110-Rt1100-Skl10:
do.bench: '0'
python.version: '3.9'
numpy.version: '>=1.21.0'
numpy.version: '>=1.21.0,<1.25.0'
scipy.version: '>=1.7.0'
onnx.version: 'onnx==1.11.0'
onnx.target_opset: ''
Expand All @@ -167,7 +167,7 @@ jobs:
Py39-Onnx1101-Rt190-Skl10:
do.bench: '0'
python.version: '3.9'
numpy.version: '>=1.21.0'
numpy.version: '>=1.21.0,<1.25.0'
scipy.version: '>=1.7.0'
onnx.version: 'onnx==1.10.1'
onnx.target_opset: ''
Expand Down
19 changes: 16 additions & 3 deletions docs/tutorial/plot_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def transform(self, X, y=None):

mapper = ColumnTransformer(
transformers=[
("c", OverpriceCalculator(), ["a", "b"]),
("ab", FunctionTransformer(), ["a", "b"]), # We keep the first column.
("c", OverpriceCalculator(), ["a", "b"]), # We add a new one.
],
remainder="passthrough",
verbose_feature_names_out=False,
Expand All @@ -101,6 +102,10 @@ def transform(self, X, y=None):
# Both pipelines return the same output.
assert_allclose(pipe.predict_proba(data), pipe_tr.predict_proba(data))

#############################
# Let's check it produces the same number of features.
assert_allclose(pipe.steps[0][-1].transform(data), pipe_tr.steps[0][-1].transform(data))

#############################
# But the conversion still fails with a different error message.

Expand Down Expand Up @@ -166,15 +171,19 @@ def overprice_converter(scope, operator, container):
# Let's check there is no discrepancies
# +++++++++++++++++++++++++++++++++++++
#
# First with :class:`onnx.reference.ReferenceEvaluator`.
# First the expected values

expected = (pipe_tr.predict(data), pipe_tr.predict_proba(data))
print(expected)

##############################
# Then let's check with :class:`onnx.reference.ReferenceEvaluator`.

feeds = {
"a": data["a"].values.reshape((-1, 1)),
"b": data["b"].values.reshape((-1, 1)),
"f": data["f"].values.reshape((-1, 1)),
}
print(feeds)

# verbose=10 to show intermediate results
ref = ReferenceEvaluator(onx, verbose=0)
Expand All @@ -191,3 +200,7 @@ def overprice_converter(scope, operator, container):

assert_allclose(expected[0], got[0])
assert_allclose(expected[1], got[1])

#######################################
# Finally.
print("done")

0 comments on commit 8d1e174

Please sign in to comment.