
Generating embeddings with ONNX Runtime leads to errors #2983

Open
aoezdTchibo opened this issue Oct 13, 2024 · 2 comments

aoezdTchibo commented Oct 13, 2024

With the new release of version 3.2.0, using ONNX has become much easier, but my initial local tests led to various errors, so I was not able to use ONNX Runtime through Sentence Transformers. Here are two examples:

  1. intfloat/multilingual-e5-small
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-small", backend="onnx")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

This leads to the following error:

AttributeError                            Traceback (most recent call last)
Cell In[20], line 6
      3 model = SentenceTransformer("intfloat/multilingual-e5-small", backend="onnx")
      5 sentences = ["This is an example sentence", "Each sentence is converted"]
----> 6 embeddings = model.encode(sentences)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:621, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
    618 features.update(extra_features)
    620 with torch.no_grad():
--> 621     out_features = self.forward(features, **kwargs)
    622     if self.device.type == "hpu":
    623         out_features = copy.deepcopy(out_features)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:688, in SentenceTransformer.forward(self, input, **kwargs)
    686     module_kwarg_keys = self.module_kwargs.get(module_name, [])
    687     module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 688     input = module(input, **module_kwargs)
    689 return input

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/models/Transformer.py:350, in Transformer.forward(self, features, **kwargs)
    347 if "token_type_ids" in features:
    348     trans_features["token_type_ids"] = features["token_type_ids"]
--> 350 output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
    351 output_tokens = output_states[0]
    353 features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/optimum/modeling_base.py:98, in OptimizedModel.__call__(self, *args, **kwargs)
     97 def __call__(self, *args, **kwargs):
---> 98     return self.forward(*args, **kwargs)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py:1106, in ORTModelForFeatureExtraction.forward(self, input_ids, attention_mask, token_type_ids, **kwargs)
   1103 else:
   1104     model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
-> 1106     onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
   1107     onnx_outputs = self.model.run(None, onnx_inputs)
   1108     model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py:940, in ORTModel._prepare_onnx_inputs(self, use_torch, **inputs)
    937 onnx_inputs[input_name] = inputs.pop(input_name)
    939 if use_torch:
--> 940     onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True)
    942 if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]:
    943     onnx_inputs[input_name] = onnx_inputs[input_name].astype(
    944         TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
    945     )

AttributeError: 'NoneType' object has no attribute 'numpy'

  2. sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", backend="onnx")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

This leads to the following error:

Fail                                      Traceback (most recent call last)
Cell In[21], line 6
      3 model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", backend="onnx")
      5 sentences = ["This is an example sentence", "Each sentence is converted"]
----> 6 embeddings = model.encode(sentences)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:621, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
    618 features.update(extra_features)
    620 with torch.no_grad():
--> 621     out_features = self.forward(features, **kwargs)
    622     if self.device.type == "hpu":
    623         out_features = copy.deepcopy(out_features)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py:688, in SentenceTransformer.forward(self, input, **kwargs)
    686     module_kwarg_keys = self.module_kwargs.get(module_name, [])
    687     module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 688     input = module(input, **module_kwargs)
    689 return input

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/sentence_transformers/models/Transformer.py:350, in Transformer.forward(self, features, **kwargs)
    347 if "token_type_ids" in features:
    348     trans_features["token_type_ids"] = features["token_type_ids"]
--> 350 output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
    351 output_tokens = output_states[0]
    353 features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/optimum/modeling_base.py:98, in OptimizedModel.__call__(self, *args, **kwargs)
     97 def __call__(self, *args, **kwargs):
---> 98     return self.forward(*args, **kwargs)

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py:1107, in ORTModelForFeatureExtraction.forward(self, input_ids, attention_mask, token_type_ids, **kwargs)
   1104 model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
   1106 onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
-> 1107 onnx_outputs = self.model.run(None, onnx_inputs)
   1108 model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
   1110 if "last_hidden_state" in self.output_names:

File ~/PycharmProjects/product-search-custom-embedding/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:220, in Session.run(self, output_names, input_feed, run_options)
    218     output_names = [output.name for output in self._outputs_meta]
    219 try:
--> 220     return self._sess.run(output_names, input_feed, run_options)
    221 except C.EPFail as err:
    222     if self._enable_fallback:

Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running CoreML_13584095612085833210_4 node. Name:'CoreMLExecutionProvider_CoreML_13584095612085833210_4_4' Status Message: Error executing model: Unable to compute the prediction using a neural network model. It can be an invalid input data or broken/unsupported model (error code: -1).

Local environment:
python=3.10
sentence-transformers=3.2.0
onnx=1.17.0
onnxruntime=1.19.2
optimum=1.23.0
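For reports like this one, the version list above can also be collected programmatically. A minimal sketch using only the standard library's importlib.metadata; the package list mirrors the one above, and environment_report is an illustrative helper name (a GPU install would be listed as onnxruntime-gpu rather than onnxruntime).

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the environment list above.
PACKAGES = ["sentence-transformers", "onnx", "onnxruntime", "optimum"]


def environment_report(packages=PACKAGES):
    """Return {name: installed version or None}, plus the Python major.minor."""
    report = {"python": f"{sys.version_info.major}.{sys.version_info.minor}"}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name}={ver}")
```

Missing packages are reported as None instead of raising, so the same snippet works in a partially installed environment.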

tomaarsen (Collaborator) commented Oct 15, 2024

Hello!

The first one seems to be an issue with Optimum. I've reported it here: huggingface/optimum#2062
In short, the tokenizer does not return token_type_ids because they are optional in transformers, but Optimum requires them as an input for BERT models.
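For single-sentence inputs, BERT's token_type_ids are simply all zeros with the same shape as input_ids, so the missing input is easy to fill in. A minimal sketch on plain Python lists; add_missing_token_type_ids is a hypothetical helper, not part of sentence-transformers or Optimum, and the token ids below are made up for illustration. With PyTorch tensors the equivalent would be torch.zeros_like(features["input_ids"]). The proper fix belongs in Optimum (huggingface/optimum#2062).

```python
def add_missing_token_type_ids(features):
    """Return a copy of tokenizer output with all-zero token_type_ids if absent.

    `features` maps input names to batches given as lists of lists of ints.
    """
    features = dict(features)  # shallow copy; leave the caller's dict untouched
    if features.get("token_type_ids") is None:
        # One segment id (0) per token, matching input_ids row by row.
        features["token_type_ids"] = [[0] * len(row) for row in features["input_ids"]]
    return features


# Illustrative tokenizer output that lacks token_type_ids.
batch = {
    "input_ids": [[0, 101, 102, 2], [0, 103, 2, 1]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],
}
print(add_missing_token_type_ids(batch)["token_type_ids"])
# [[0, 0, 0, 0], [0, 0, 0, 0]]
```

Existing token_type_ids are passed through unchanged, so the helper is a no-op for tokenizers that already return them.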

I'm not familiar with the second error, and I also can't reproduce it, but it seems that you're using the CoreMLExecutionProvider (selected by default), an execution provider I haven't worked with. Could you try it with the CPUExecutionProvider instead:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", backend="onnx", model_kwargs={"provider": "CPUExecutionProvider"})

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
  • Tom Aarsen
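If the CoreML provider is the culprit, one way to avoid it without hard-coding the CPU provider everywhere is to pick from the providers the installed onnxruntime build offers. A minimal sketch, assuming the preference order below; pick_provider and its ordering are illustrative, not part of sentence-transformers. In practice, available would come from onnxruntime.get_available_providers(), and the result would be passed as model_kwargs={"provider": ...} as in the snippet above.

```python
# Preference order: CoreML is deliberately left out, since it failed above.
DEFAULT_PREFERENCE = ("CUDAExecutionProvider", "CPUExecutionProvider")


def pick_provider(available, preference=DEFAULT_PREFERENCE):
    """Return the first provider in `preference` that appears in `available`."""
    for name in preference:
        if name in available:
            return name
    raise RuntimeError(f"none of {preference} available; got {list(available)}")


# Illustrative provider list, e.g. from an onnxruntime build on macOS.
mac = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
print(pick_provider(mac))  # CPUExecutionProvider
```

Raising when nothing matches makes a misconfigured environment fail loudly instead of silently running on an unexpected provider.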

@loretoparisi

I'm testing the top 10 multilingual SBERT models from the MTEB leaderboard, and several issues came up:

from sentence_transformers import SentenceTransformer

# top 10 MTEB leaderboard multilingual SBERT embeddings models
models = [
    'BAAI/bge-multilingual-gemma2',
    'intfloat/multilingual-e5-large-instruct',
    'HIT-TMG/KaLM-embedding-multilingual-mini-v1',
    'gte-multilingual-base',
    'Alibaba-NLP/gte-multilingual-base',
    'intfloat/multilingual-e5-base',
    'intfloat/multilingual-e5-small'
]
for model_name in models:
    try:
        model = SentenceTransformer(
            model_name,
            backend="onnx",
            model_kwargs={
                "provider": "CPUExecutionProvider",
                # not supported with onnx
                # "torch_dtype": torch.float16
            },
            trust_remote_code=True,
            cache_folder='/mnt/datasets/sbert',
        )
        # print(model) shows, e.g.:
        # SentenceTransformer(
        #   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: ORTModelForFeatureExtraction
        #   (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
        #   (2): Normalize()
        # )
        print(model)
        sentences = ["This is an example sentence", "Each sentence is converted"]
        embeddings = model.encode(sentences)
        # shape is a tuple attribute, not a method: calling embeddings.shape()
        # raises "TypeError: 'tuple' object is not callable" after encode() has returned
        print(embeddings.shape)
    except Exception as e:
        print(f'error loading {model_name} {str(e)}')

Stacktrace

  • BAAI/bge-multilingual-gemma2
No 'model.onnx' found in 'BAAI/bge-multilingual-gemma2'. Exporting the model to ONNX.
Loading checkpoint shards: 100%
error loading BAAI/bge-multilingual-gemma2 Trying to export a gemma2 model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type gemma2 to be supported natively in the ONNX export.
  • intfloat/multilingual-e5-large-instruct
error loading intfloat/multilingual-e5-large-instruct 'tuple' object is not callable
No 'model.onnx' found in 'HIT-TMG/KaLM-embedding-multilingual-mini-v1'. Exporting the model to ONNX.
tokenization_qwen.py: 100%|
A new version of the following files was downloaded from https://huggingface.co/HIT-TMG/KaLM-embedding-multilingual-mini-v1:
- tokenization_qwen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/HIT-TMG/KaLM-embedding-multilingual-mini-v1:
- tokenization_qwen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:103: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sequence_length != 1:
Saving the exported ONNX model is heavily recommended to avoid having to export it again. Do so with `model.push_to_hub('HIT-TMG/KaLM-embedding-multilingual-mini-v1', create_pr=True)`.
tokenization_qwen.py: 100%
A new version of the following files was downloaded from https://huggingface.co/HIT-TMG/KaLM-embedding-multilingual-mini-v1:
- tokenization_qwen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
  • HIT-TMG/KaLM-embedding-multilingual-mini-v1
error loading HIT-TMG/KaLM-embedding-multilingual-mini-v1 'position_ids'
No sentence-transformers model found with name sentence-transformers/gte-multilingual-base. Creating a new one with mean pooling.
  • gte-multilingual-base
error loading gte-multilingual-base sentence-transformers/gte-multilingual-base is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
No 'model.onnx' found in 'Alibaba-NLP/gte-multilingual-base'. Exporting the model to ONNX.

A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.

Some weights of the model checkpoint at Alibaba-NLP/gte-multilingual-base were not used when initializing NewModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
tokenizer_config.json: 100%|
A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
  • Alibaba-NLP/gte-multilingual-base
error loading Alibaba-NLP/gte-multilingual-base Trying to export a new model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type new to be supported natively in the ONNX export.
  • intfloat/multilingual-e5-base
error loading intfloat/multilingual-e5-base 'tuple' object is not callable
  • intfloat/multilingual-e5-small
error loading intfloat/multilingual-e5-small 'NoneType' object has no attribute 'numpy'
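The test loop in this comment prints every exception as "error loading", even when loading succeeded and a later step failed, which makes the results above harder to attribute. A minimal sketch of a harness that records the failing stage instead; check_models is a hypothetical helper, and the loader and encoder are passed in as callables so it can be exercised without downloading any model.

```python
def check_models(model_names, load, encode):
    """Try each model and record which stage failed, if any.

    `load(name)` should return a model; `encode(model, sentences)` should
    return an array-like object with a `.shape` attribute.
    """
    sentences = ["This is an example sentence", "Each sentence is converted"]
    results = {}
    for name in model_names:
        try:
            model = load(name)
        except Exception as e:  # download / ONNX export problems land here
            results[name] = ("load", f"{type(e).__name__}: {e}")
            continue
        try:
            embeddings = encode(model, sentences)
        except Exception as e:  # runtime problems, e.g. missing model inputs
            results[name] = ("encode", f"{type(e).__name__}: {e}")
            continue
        # .shape is read as an attribute, never called.
        results[name] = ("ok", tuple(embeddings.shape))
    return results
```

With sentence-transformers installed, load could be lambda name: SentenceTransformer(name, backend="onnx", model_kwargs={"provider": "CPUExecutionProvider"}, trust_remote_code=True) and encode could be lambda model, s: model.encode(s). Keeping the stages separate distinguishes export failures (such as the custom-architecture errors) from inference failures (such as the missing token_type_ids).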
