Skip to content

Commit

Permalink
[ONNX export] Add depth-estimation w/ DPT+GLPN (#1529)
Browse files Browse the repository at this point in the history
* Add depth-estimation w/ dpt

* Fix depth-estimation outputs

* Add GLPN onnx export
  • Loading branch information
xenova authored Nov 14, 2023
1 parent 78767f0 commit e3b7efb
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class OnnxConfig(ExportConfig, ABC):
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"automatic-speech-recognition": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"audio-xvector": OrderedDict({"logits": {0: "batch_size"}, "embeddings": {0: "batch_size"}}),
"depth-estimation": OrderedDict({"predicted_depth": {0: "batch_size", 1: "height", 2: "width"}}),
"document-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,14 @@ class Swin2srOnnxConfig(SwinOnnxConfig):
pass


class DptOnnxConfig(ViTOnnxConfig):
pass


class GlpnOnnxConfig(ViTOnnxConfig):
pass


class PoolFormerOnnxConfig(ViTOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
ATOL_FOR_VALIDATION = 2e-3
Expand Down
11 changes: 11 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class TasksManager:
"audio-xvector": "AutoModelForAudioXVector",
"automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
"conversational": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"depth-estimation": "AutoModelForDepthEstimation",
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
Expand Down Expand Up @@ -497,6 +498,11 @@ class TasksManager:
"feature-extraction",
onnx="DonutSwinOnnxConfig",
),
"dpt": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
onnx="DptOnnxConfig",
),
"electra": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down Expand Up @@ -533,6 +539,11 @@ class TasksManager:
onnx="FlaubertOnnxConfig",
tflite="FlaubertTFLiteConfig",
),
"glpn": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
onnx="GlpnOnnxConfig",
),
"gpt2": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
"detr": "hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"dpt": "hf-internal-testing/tiny-random-DPTModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"encoder-decoder": {
"hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert": [
Expand All @@ -84,6 +85,7 @@
"fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"],
},
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"glpn": "hf-internal-testing/tiny-random-GLPNModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
Expand Down

0 comments on commit e3b7efb

Please sign in to comment.