From fe71b3c41ac2d3c953aeda73649da263f97aefdd Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:31:52 +0100 Subject: [PATCH] Fix emb model export and load with trfrs (#756) # What does this PR do? Fixes #744 With this PR, we should once again be able to export embedding models via the transformers library or the sentence-transformers library, depending on the class called: * With Transformers ```python import torch from optimum.neuron import NeuronModelForFeatureExtraction from transformers import AutoConfig, AutoTokenizer compiler_args = {"auto_cast": "matmul", "auto_cast_type": "fp16"} input_shapes = {"batch_size": 4, "sequence_length": 512} model = NeuronModelForFeatureExtraction.from_pretrained( model_id="TaylorAI/bge-micro-v2", # BERT SMALL export=True, disable_neuron_cache=True, **compiler_args, **input_shapes, ) ``` * With Sentence Transformers ```python import torch from optimum.neuron import NeuronModelForSentenceTransformers from transformers import AutoConfig, AutoTokenizer compiler_args = {"auto_cast": "matmul", "auto_cast_type": "fp16"} input_shapes = {"batch_size": 4, "sequence_length": 512} model = NeuronModelForSentenceTransformers.from_pretrained( model_id="TaylorAI/bge-micro-v2", # BERT SMALL export=True, disable_neuron_cache=True, **compiler_args, **input_shapes, ) ``` --- optimum/neuron/modeling_traced.py | 1 + tests/inference/test_modeling.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py index 6d37237c6..ef92da0b0 100644 --- a/optimum/neuron/modeling_traced.py +++ b/optimum/neuron/modeling_traced.py @@ -363,6 +363,7 @@ def _export( local_files_only=local_files_only, token=token, do_validation=False, + library_name=cls.library_name, **kwargs_shapes, ) config = AutoConfig.from_pretrained(save_dir_path) diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py index 
51c873db0..dd6d3ffc7 100644 --- a/tests/inference/test_modeling.py +++ b/tests/inference/test_modeling.py @@ -109,7 +109,6 @@ def test_load_model_from_hub_subfolder(self): self.TINY_SUBFOLDER_MODEL_ID, subfolder="my_subfolder", export=True, - library_name="transformers", **self.STATIC_INPUTS_SHAPES, ) self.assertIsInstance(model.model, torch.jit._script.ScriptModule)