diff --git a/docs/source/exporters/onnx/package_reference/configuration.mdx b/docs/source/exporters/onnx/package_reference/configuration.mdx index 4443a4fb18..f17b66701f 100644 --- a/docs/source/exporters/onnx/package_reference/configuration.mdx +++ b/docs/source/exporters/onnx/package_reference/configuration.mdx @@ -117,6 +117,7 @@ They specify which input generators should be used for the dummy inputs, but rem - Segformer - SEW - Speech2Text +- Splinter - SqueezeBert - Stable Diffusion - T5 diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 6ce8f62af3..f3cc84d0dc 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -107,6 +107,10 @@ class XLMOnnxConfig(BertOnnxConfig): pass +class SplinterOnnxConfig(BertOnnxConfig): + pass + + class DistilBertOnnxConfig(BertOnnxConfig): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 95531f7c40..a8574c58bf 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -600,6 +600,11 @@ class TasksManager: "speech2seq-lm-with-past", onnx="Speech2TextOnnxConfig", ), + "splinter": supported_tasks_mapping( + "default", + "question-answering", + onnx="SplinterOnnxConfig", + ), "squeezebert": supported_tasks_mapping( "default", "masked-lm", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 900778bf9f..2edd5f09f7 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -198,6 +198,7 @@ class NormalizedConfigManager: "poolformer": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, + "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, "whisper": WhisperLikeNormalizedTextConfig, "xlm-roberta": NormalizedTextConfig, diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 7c6ea9d7df..ec698fd790 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -83,6 +83,7 @@ "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "segformer": "hf-internal-testing/tiny-random-SegformerModel", + "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "swin": "hf-internal-testing/tiny-random-SwinModel", "t5": "hf-internal-testing/tiny-random-t5", @@ -175,6 +176,7 @@ "roberta": "roberta-base", "roformer": "junnyu/roformer_chinese_base", "segformer": "nvidia/segformer-b0-finetuned-ade-512-512", + "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "squeezebert/squeezebert-uncased", "swin": "microsoft/swin-tiny-patch4-window7-224", "t5": "t5-small",