
No ONNX support for BERT models when token_type_ids is not provided #2062

Open · 2 of 4 tasks

tomaarsen opened this issue Oct 15, 2024 · 1 comment

Labels: bug (Something isn't working)

tomaarsen (Member) commented Oct 15, 2024

System Info

optimum==1.23.1
transformers==4.43.4
onnxruntime-gpu==1.19.2
sentence-transformers==3.2.0

Windows
Python 3.11.6

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoModel, AutoTokenizer

model_id = "intfloat/multilingual-e5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

provider = "CPUExecutionProvider"
# provider = "CUDAExecutionProvider"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True, provider=provider)

inputs = tokenizer("This is my test sentence", return_tensors="pt")
print(inputs.keys())
# => dict_keys(['input_ids', 'attention_mask'])

outputs = model(**inputs)
print(outputs[0].shape)
# => torch.Size([1, 7, 384])
onnx_outputs = onnx_model(**inputs)
print(onnx_outputs[0].shape)
# If CPUExecutionProvider => AttributeError: 'NoneType' object has no attribute 'numpy'
# If CUDAExecutionProvider => KeyError: 'token_type_ids'
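
As a user-side workaround, explicitly passing a zero-filled token_type_ids tensor makes the ONNX model run (a minimal sketch reusing onnx_model and inputs from above; the expected output shape is an assumption based on the PyTorch result):

import torch

# Supply the zero tensor that transformers would normally create
# internally when the tokenizer does not return token_type_ids.
inputs["token_type_ids"] = torch.zeros_like(inputs["input_ids"])
onnx_outputs = onnx_model(**inputs)
print(onnx_outputs[0].shape)
# expected: torch.Size([1, 7, 384]), matching the PyTorch output above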

Expected behavior

I would expect optimum to mirror the transformers behaviour, where token_type_ids is set to torch.zeros(input_ids.shape, ...) if it is not explicitly provided.
See here for that implementation in transformers: https://github.com/huggingface/transformers/blob/4de1bdbf637fe6411c104c62ab385f660bfb1064/src/transformers/models/bert/modeling_bert.py#L1070-L1076
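
For reference, the linked transformers logic is roughly the following (a paraphrase, not a verbatim copy; recent BERT versions first try a zero buffer registered on the embeddings module before falling back to torch.zeros):

# Inside BertModel.forward, when the caller omits token_type_ids:
if token_type_ids is None:
    if hasattr(self.embeddings, "token_type_ids"):
        # expand the registered zero buffer to (batch_size, seq_length)
        token_type_ids = self.embeddings.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
    else:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)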

This is preventing the following from working:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-small", backend="onnx")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

See also UKPLab/sentence-transformers#2983

Tom Aarsen

echarlaix (Collaborator) commented:
Thanks for reporting @tomaarsen! This is something we already do for openvino models (https://github.com/huggingface/optimum-intel/blob/f7b5b547c167cb6a9f20fa77d493ee2dde3c3034/optimum/intel/openvino/modeling.py#L395), but it was never added for onnx models; I'll take care of adding it!
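
For illustration, the equivalent onnx-side fix might look something like this (a hypothetical sketch mirroring the linked openvino handling; self.input_names follows optimum's ORTModel conventions, but the exact integration point is up to the maintainers):

import numpy as np

# Hypothetical: if the exported graph declares a token_type_ids input
# but the caller did not provide one, feed a zero tensor shaped like
# input_ids before building the ONNX Runtime input feed.
if "token_type_ids" in self.input_names and token_type_ids is None:
    token_type_ids = np.zeros_like(input_ids)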

@echarlaix echarlaix self-assigned this Oct 18, 2024