Unexpected behavior for float64 with InferenceSession sklearn #1074

Open
maulberto3 opened this issue Feb 27, 2024 · 1 comment

@maulberto3

The following code doesn't work. ONNX Runtime complains that it expects float32 even though everything (data, model, and conversion) is float64: [ONNXRuntimeError] : 1 : FAIL : Load model from rf_iris.onnx failed:Type Error: Type (tensor(double)) of output arg (variable) of node (TreeEnsembleRegressor) does not match expected type (tensor(float)).

Is this expected, or am I missing something?

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import DoubleTensorType

import onnxruntime as rt


# Train
X, y = make_regression(100, 10, bias=1, noise=1, random_state=16)  # float64
print(X.dtype, y.dtype)  # float64 float64
X_train, X_test, y_train, y_test = train_test_split(X, y)
rf = RandomForestRegressor(n_jobs=-1, random_state=16)
_ = rf.fit(X_train, y_train)
print(rf.predict(X_test).dtype)  # float64

# Convert
init_types = [('double_input', DoubleTensorType([None, 10]))]
onx = convert_sklearn(rf, initial_types=init_types)
with open("rf_iris.onnx", "wb") as f:
    _ = f.write(onx.SerializeToString())

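# Call: loading the model fails here with the type error quoted above
sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
print(sess.run(None, {input_name: X_test})[0])
print(sess.run([label_name], {input_name: X_test})[0])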
@xadupre
Collaborator

xadupre commented Apr 4, 2024

Some versions of onnxruntime only support float for TreeEnsemble. Which version are you using? If it is the latest, 1.17.1, this should have been fixed, but maybe the converter library was not updated to support it. I'll check.
