From 6001bb1dce867817b76528046e15d9f7b6a2698f Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 16 Oct 2024 16:27:16 +0200 Subject: [PATCH 1/2] Support adapters on SentenceTransformer --- sentence_transformers/SentenceTransformer.py | 3 +- sentence_transformers/peft_mixin.py | 114 +++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 sentence_transformers/peft_mixin.py diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 1a8cb2efb..ff48cc04e 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -35,6 +35,7 @@ from . import __MODEL_HUB_ORGANIZATION__, __version__ from .evaluation import SentenceEvaluator from .fit_mixin import FitMixin +from .peft_mixin import PeftAdapterMixin from .models import Normalize, Pooling, Transformer from .quantization import quantize_embeddings from .util import ( @@ -51,7 +52,7 @@ logger = logging.getLogger(__name__) -class SentenceTransformer(nn.Sequential, FitMixin): +class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin): """ Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings. diff --git a/sentence_transformers/peft_mixin.py b/sentence_transformers/peft_mixin.py new file mode 100644 index 000000000..338fd092d --- /dev/null +++ b/sentence_transformers/peft_mixin.py @@ -0,0 +1,114 @@ +import logging +from transformers.integrations.peft import PeftAdapterMixin as PeftAdapterMixinTransformers +logger = logging.getLogger(__name__) + + +class PeftAdapterMixin: + """ + Wrapper Mixin that adds the functionality to easily load and use adapters on the model. For + more details about adapters check out the documentation of PEFT + library: https://huggingface.co/docs/peft/index + + Currently supported PEFT methods follow those supported by transformers library, + you can find more information on: + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin + """ + + def load_adapter( + self, + *args, + **kwargs, + ) -> None: + """ + Load adapter weights from file or remote Hub folder." If you are not familiar with adapters and PEFT methods, we + invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft + + Requires peft as a backend to load the adapter weights. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter + """ + self._modules["0"].auto_model.load_adapter(*args, **kwargs) + + def add_adapter(self, *args, **kwargs) -> None: + """ + Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default + name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the + default adapter name). + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter + + """ + self._modules["0"].auto_model.add_adapter(*args, **kwargs) + + def set_adapter(self, *args, **kwargs) -> None: + """ + Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter + """ + self._modules["0"].auto_model.set_adapter(*args, **kwargs) + + def disable_adapters(self) -> None: + """ + Disable all adapters that are attached to the model. This leads to inferring with the base model only. + """ + self._modules["0"].auto_model.disable_adapters() + + + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. The model will use `self.active_adapter()` + """ + self._modules["0"].auto_model.enable_adapters() + + def active_adapters(self) -> list[str]: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters + for inference) returns the list of all active adapters so that users can deal with them accordingly. + + For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return + a single string. + """ + return self._modules["0"].auto_model.active_adapters() + + def active_adapter(self) -> str: + return self._modules["0"].auto_model.active_adapter() + + def get_adapter_state_dict(self, *args, **kwargs) -> dict: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. + If no adapter_name is passed, the active adapter is used. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict + **kwargs: + Keyword arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict + """ + return self._modules["0"].auto_model.get_adapter_state_dict(*args, **kwargs) \ No newline at end of file From e58a7148ba6d8e3b78d8680c2c0c692a2e85046b Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Thu, 17 Oct 2024 11:14:18 +0200 Subject: [PATCH 2/2] Add transformer model check + add testing --- pyproject.toml | 2 +- sentence_transformers/SentenceTransformer.py | 2 +- sentence_transformers/peft_mixin.py | 67 +++++++++++++++----- tests/test_sentence_transformer.py | 65 +++++++++++++++++++ 4 files changed, 118 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 11550a5e7..31956969f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ train = ["datasets", "accelerate>=0.20.3"] onnx = ["optimum[onnxruntime]>=1.23.1"] onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.1"] openvino = ["optimum-intel[openvino]>=1.20.0"] -dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov"] +dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov", "peft"] [build-system] requires = ["setuptools>=42", "wheel"] diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index ff48cc04e..dbbc10679 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -35,8 +35,8 @@ from . import __MODEL_HUB_ORGANIZATION__, __version__ from .evaluation import SentenceEvaluator from .fit_mixin import FitMixin -from .peft_mixin import PeftAdapterMixin from .models import Normalize, Pooling, Transformer +from .peft_mixin import PeftAdapterMixin from .quantization import quantize_embeddings from .util import ( batch_to_device, diff --git a/sentence_transformers/peft_mixin.py b/sentence_transformers/peft_mixin.py index 338fd092d..0d0b8c83b 100644 --- a/sentence_transformers/peft_mixin.py +++ b/sentence_transformers/peft_mixin.py @@ -1,6 +1,22 @@ -import logging +from __future__ import annotations + +from functools import wraps + from transformers.integrations.peft import PeftAdapterMixin as PeftAdapterMixinTransformers -logger = logging.getLogger(__name__) + +from .models import Transformer + + +def peft_wrapper(func): + """Wrapper to call the method on the auto_model with a check for PEFT compatibility.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + self.check_peft_compatible_model() + method = getattr(self._modules["0"].auto_model, func.__name__) + return method(*args, **kwargs) + + return wrapper class PeftAdapterMixin: @@ -9,11 +25,23 @@ class PeftAdapterMixin: more details about adapters check out the documentation of PEFT library: https://huggingface.co/docs/peft/index - Currently supported PEFT methods follow those supported by transformers library, + Currently supported PEFT methods follow those supported by transformers library, you can find more information on: https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin """ + def has_peft_compatible_model(self) -> bool: + return isinstance(self._modules["0"], Transformer) and isinstance( + self._modules["0"].auto_model, PeftAdapterMixinTransformers + ) + + def check_peft_compatible_model(self) -> None: + if not self.has_peft_compatible_model(): + raise ValueError( + "PEFT methods are only supported for Transformers models that inherit from PeftAdapterMixin" + ) + + @peft_wrapper def load_adapter( self, *args, @@ -23,7 +51,7 @@ def load_adapter( Load adapter weights from file or remote Hub folder." If you are not familiar with adapters and PEFT methods, we invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft - Requires peft as a backend to load the adapter weights. + Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT. Args: *args: @@ -33,14 +61,17 @@ def load_adapter( Keyword arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter """ - self._modules["0"].auto_model.load_adapter(*args, **kwargs) + ... # Implementation handled by the wrapper + @peft_wrapper def add_adapter(self, *args, **kwargs) -> None: """ - Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default + Adds a fresh new adapter to the current model for training purposes. If no adapter name is passed, a default name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the default adapter name). + Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT. + Args: *args: Positional arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation @@ -48,10 +79,11 @@ def add_adapter(self, *args, **kwargs) -> None: **kwargs: Keyword arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter - + """ - self._modules["0"].auto_model.add_adapter(*args, **kwargs) + ... # Implementation handled by the wrapper + @peft_wrapper def set_adapter(self, *args, **kwargs) -> None: """ Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. @@ -64,21 +96,23 @@ def set_adapter(self, *args, **kwargs) -> None: Keyword arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter """ - self._modules["0"].auto_model.set_adapter(*args, **kwargs) + ... # Implementation handled by the wrapper + @peft_wrapper def disable_adapters(self) -> None: """ Disable all adapters that are attached to the model. This leads to inferring with the base model only. """ - self._modules["0"].auto_model.disable_adapters() - + ... # Implementation handled by the wrapper + @peft_wrapper def enable_adapters(self) -> None: """ Enable adapters that are attached to the model. The model will use `self.active_adapter()` """ - self._modules["0"].auto_model.enable_adapters() + ... # Implementation handled by the wrapper + @peft_wrapper def active_adapters(self) -> list[str]: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT @@ -90,11 +124,12 @@ def active_adapters(self) -> list[str]: For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return a single string. """ - return self._modules["0"].auto_model.active_adapters() + ... # Implementation handled by the wrapper - def active_adapter(self) -> str: - return self._modules["0"].auto_model.active_adapter() + @peft_wrapper + def active_adapter(self) -> str: ... # Implementation handled by the wrapper + @peft_wrapper def get_adapter_state_dict(self, *args, **kwargs) -> dict: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT @@ -111,4 +146,4 @@ def get_adapter_state_dict(self, *args, **kwargs) -> dict: Keyword arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict """ - return self._modules["0"].auto_model.get_adapter_state_dict(*args, **kwargs) \ No newline at end of file + ... # Implementation handled by the wrapper diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 9d8e9c347..d699ecfdc 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -687,3 +687,68 @@ def test_empty_encode(stsb_bert_tiny_model: SentenceTransformer) -> None: model = stsb_bert_tiny_model embeddings = model.encode([]) assert embeddings.shape == (0,) + + +def test_multiple_adapters() -> None: + text = "Hello, World!" + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") + vec_initial = model.encode(text) + from peft import LoraConfig, TaskType, get_model_status + + # Adding a fresh adapter + peft_config = LoraConfig( + target_modules=["query", "key", "value"], + task_type=TaskType.FEATURE_EXTRACTION, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + init_lora_weights=False, # Random initialization to test the adapter + ) + model.add_adapter(peft_config) + + # Load an adapter from the hub + # TODO: Upload the adapter to the hub to test the loading (confirmed to be working with manual testing) + # model.load_adapter("sentence-transformers-testing/stsb-bert-tiny-lora", "hub_adapter") + + # Adding another one with a different name + peft_config = LoraConfig( + target_modules=["value"], + task_type=TaskType.FEATURE_EXTRACTION, + inference_mode=False, + r=2, + lora_alpha=16, + lora_dropout=0.1, + init_lora_weights=False, # Random initialization to test the adapter + ) + model.add_adapter(peft_config, "my_adapter") + + # Check that peft recognizes the adapters while we compute vectors for later comparison + status = get_model_status(model) + assert status.available_adapters == ["default", "my_adapter"] # ["default", "my_adapter", "hub_adapter"] + assert status.enabled + assert status.active_adapters == ["my_adapter"] + assert status.active_adapters == model.active_adapters() + vec_my_adapter = model.encode(text) + + model.set_adapter("default") + status = get_model_status(model) + assert status.active_adapters == ["default"] + vec_default_adapter = model.encode(text) + + model.disable_adapters() + status = get_model_status(model) + assert not status.enabled + vec_no_adapter = model.encode(text) + + # Check that each vector is different + assert not np.allclose(vec_my_adapter, vec_default_adapter) + assert not np.allclose(vec_my_adapter, vec_no_adapter) + assert not np.allclose(vec_default_adapter, vec_no_adapter) + # Check that the vectors from the original model match + assert np.allclose(vec_initial, vec_no_adapter) + + # Check that for non Transformer-based models we have an error + model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency") + with pytest.raises(ValueError, match="PEFT methods are only supported"): + model.add_adapter(peft_config)