From dd02f26cb819965cbf86e16d9ce013cddc3b86af Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 1 Jul 2024 18:40:26 +0200 Subject: [PATCH] Pin eager attn in torch-ort backend (#219) --- optimum_benchmark/backends/torch_ort/backend.py | 3 +++ optimum_benchmark/backends/torch_ort/config.py | 4 +++- tests/configs/_text_encoders_.yaml | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/optimum_benchmark/backends/torch_ort/backend.py b/optimum_benchmark/backends/torch_ort/backend.py index ab13ab66..5f915001 100644 --- a/optimum_benchmark/backends/torch_ort/backend.py +++ b/optimum_benchmark/backends/torch_ort/backend.py @@ -82,6 +82,9 @@ def automodel_kwargs(self) -> Dict[str, Any]: if self.config.torch_dtype is not None: kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) + if self.config.attn_implementation is not None: + kwargs["attn_implementation"] = self.config.attn_implementation + return kwargs def train( diff --git a/optimum_benchmark/backends/torch_ort/config.py b/optimum_benchmark/backends/torch_ort/config.py index 252ee72b..adc37288 100644 --- a/optimum_benchmark/backends/torch_ort/config.py +++ b/optimum_benchmark/backends/torch_ort/config.py @@ -8,12 +8,14 @@ @dataclass class TorchORTConfig(BackendConfig): name: str = "torch-ort" - version: Optional[str] = torch_ort_version + version: Optional[str] = torch_ort_version() _target_: str = "optimum_benchmark.backends.torch_ort.backend.TorchORTBackend" # load options no_weights: bool = False torch_dtype: Optional[str] = None + # sdpa, which has became default of many architectures, fails with torch ort + attn_implementation: Optional[str] = "eager" # peft options peft_type: Optional[str] = None diff --git a/tests/configs/_text_encoders_.yaml b/tests/configs/_text_encoders_.yaml index 404cd350..9e047966 100644 --- a/tests/configs/_text_encoders_.yaml +++ b/tests/configs/_text_encoders_.yaml @@ -3,4 +3,4 @@ hydra: sweeper: params: backend.task: fill-mask,text-classification,token-classification,question-answering - backend.model: hf-internal-testing/tiny-random-bert,hf-internal-testing/tiny-random-roberta + backend.model: hf-internal-testing/tiny-random-BertModel,hf-internal-testing/tiny-random-RobertaModel