From 1fd0b994f841e6e33d914572a826103b7e752c89 Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Wed, 20 Nov 2024 14:12:06 -0500
Subject: [PATCH 1/7] Initial commit.

---
 .../extractor/embedding_onnx_extractor.py     |  55 +++++
 ads/model/framework/embedding_onnx_model.py   |  41 ++++
 ads/templates/score_embedding_onnx.jinja2     | 190 ++++++++++++++++++
 3 files changed, 286 insertions(+)
 create mode 100644 ads/model/extractor/embedding_onnx_extractor.py
 create mode 100644 ads/model/framework/embedding_onnx_model.py
 create mode 100644 ads/templates/score_embedding_onnx.jinja2
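
This patch stubs out the extractor and model classes and adds the score.py
template. The template speaks an OpenAI-compatible embeddings protocol; the
sketch below shows the request and response shapes it validates and emits
(field names follow the template's `validate_inputs` example and
`convert_embeddings_to_openapi_format`; the values are illustrative):

    # OpenAICompatRequest: the payload predict() expects ("input" is required).
    request = {
        "input": ["What are activation functions?"],        # one or more strings
        "model": "sentence-transformers/all-MiniLM-L6-v2",  # tokenizer to load
        "encoding_format": "float",                         # optional
        "user": "user",                                     # optional
    }

    # OpenAICompatResponse: what post_inference() assembles from the ONNX outputs.
    response = {
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": [0.0, 1.0, 2.0], "index": 0},
        ],
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "usage": {"prompt_tokens": 7, "total_tokens": 7},   # equal for embeddings
    }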

diff --git a/ads/model/extractor/embedding_onnx_extractor.py b/ads/model/extractor/embedding_onnx_extractor.py
new file mode 100644
index 000000000..becfef0a7
--- /dev/null
+++ b/ads/model/extractor/embedding_onnx_extractor.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+from ads.model.extractor.model_info_extractor import ModelInfoExtractor
+
+
+class EmbeddingONNXExtractor(ModelInfoExtractor):
+    def __init__(self, model):
+        self.model = model
+
+    @property
+    def framework(self):
+        """Extracts the framework of the model.
+
+        Returns
+        -------
+        str:
+           The framework of the model.
+        """
+        pass
+
+    @property
+    def algorithm(self):
+        """Extracts the algorithm of the model.
+
+        Returns
+        -------
+        object:
+           The algorithm of the model.
+        """
+        pass
+
+    @property
+    def version(self):
+        """Extracts the framework version of the model.
+
+        Returns
+        -------
+        str:
+           The framework version of the model.
+        """
+        pass
+
+    @property
+    def hyperparameter(self):
+        """Extracts the hyperparameters of the model.
+
+        Returns
+        -------
+        dict:
+           The hyperparameters of the model.
+        """
+        pass
diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
new file mode 100644
index 000000000..440733f51
--- /dev/null
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+from typing import Any, Callable, Dict
+
+from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
+from ads.model.generic_model import FrameworkSpecificModel
+from ads.model.model_properties import ModelProperties
+from ads.model.serde.common import SERDE
+
+
+class EmbeddingONNXModel(FrameworkSpecificModel):
+    def __init__(
+        self,
+        estimator: Callable[..., Any] = None,
+        artifact_dir: str | None = None,
+        properties: ModelProperties | None = None,
+        auth: Dict | None = None,
+        serialize: bool = True,
+        model_save_serializer: SERDE | None = None,
+        model_input_serializer: SERDE | None = None,
+        **kwargs: dict,
+    ) -> None:
+        super().__init__(
+            estimator,
+            artifact_dir,
+            properties,
+            auth,
+            serialize,
+            model_save_serializer,
+            model_input_serializer,
+            **kwargs,
+        )
+
+        self._extractor = EmbeddingONNXExtractor(estimator)
+        self.framework = self._extractor.framework
+        self.algorithm = self._extractor.algorithm
+        self.version = self._extractor.version
+        self.hyperparameter = self._extractor.hyperparameter
diff --git a/ads/templates/score_embedding_onnx.jinja2 b/ads/templates/score_embedding_onnx.jinja2
new file mode 100644
index 000000000..8a830f073
--- /dev/null
+++ b/ads/templates/score_embedding_onnx.jinja2
@@ -0,0 +1,190 @@
+# score.py 1.0 generated by ADS 2.11.10 on 20241002_212041
+import os
+import sys
+import json
+from functools import lru_cache
+import onnxruntime as ort
+import jsonschema
+from jsonschema import validate, ValidationError
+from transformers import AutoTokenizer
+import logging
+
+model_name = 'model.onnx'
+openapi_schema = ''
+
+
+"""
+   Inference script. This script is used for prediction by the scoring server when the schema is known.
+"""
+
+
+@lru_cache(maxsize=10)
+def load_model(model_file_name=model_name):
+    """
+    Loads model from the serialized format
+
+    Returns
+    -------
+    model:  a model instance on which predict API can be invoked
+    """
+    model_dir = os.path.dirname(os.path.realpath(__file__))
+    if model_dir not in sys.path:
+        sys.path.insert(0, model_dir)
+    contents = os.listdir(model_dir)
+    if model_file_name in contents:
+        # print(f'Start loading {model_file_name} from model directory {model_dir} ...')
+        model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        # print("Model is successfully loaded.")
+        return model
+    else:
+        raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
+
+
+@lru_cache(maxsize=1)
+def load_tokenizer(model_full_name):
+
+    # todo: do we need model_full_name or have configs in artifact dir?
+    model_dir = os.path.dirname(os.path.realpath(__file__))
+    # initialize tokenizer
+    return AutoTokenizer.from_pretrained(model_dir, clean_up_tokenization_spaces=True)
+
+@lru_cache(maxsize=1)
+def load_openapi_schema():
+    """
+    Loads the input schema for the incoming request
+
+    Returns
+    -------
+    schema:  openapi schema as json
+    """
+    model_dir = os.path.dirname(os.path.realpath(__file__))
+    if model_dir not in sys.path:
+        sys.path.insert(0, model_dir)
+    contents = os.listdir(model_dir)
+
+    try:
+        with open(os.path.join(model_dir, openapi_schema), 'r') as file:
+            return json.load(file)
+    except OSError as err:
+        raise Exception(f'{openapi_schema} is not found in model directory {model_dir}') from err
+
+
+def validate_inputs(data):
+
+    api_schema = load_openapi_schema()
+
+    # use a reference resolver for internal $refs
+    resolver = jsonschema.RefResolver.from_schema(api_schema)
+
+    # get the actual schema part to validate against
+    request_schema = api_schema["components"]["schemas"]["OpenAICompatRequest"]
+
+    try:
+        # validate the input JSON
+        validate(instance=data, schema=request_schema, resolver=resolver)
+    except ValidationError as e:
+        # todo: add custom error code and message in error handler
+        example_value = {
+            "input": ["What are activation functions?"],
+            "encoding_format": "float",
+            "model": "sentence-transformers/all-MiniLM-L6-v2",
+            "user": "user"
+        }
+        message = f"JSON is invalid. Error: {e.message}\nAn example of the expected format for 'OpenAICompatRequest':\n{json.dumps(example_value, indent=2)}"
+        raise ValueError(message) from e
+
+
+def pre_inference(data):
+    """
+    Preprocess data
+
+    Parameters
+    ----------
+    data: Data format as expected by the predict API.
+
+    Returns
+    -------
+    onnx_inputs: Data format after any processing
+    total_tokens: total tokens that will be processed by the model
+
+    """
+    validate_inputs(data)
+
+    tokenizer = load_tokenizer(data['model'])
+    inputs = tokenizer(data['input'], return_tensors="np", padding=True)
+
+    padding_token_id = tokenizer.pad_token_id
+    total_tokens = (inputs["input_ids"] != padding_token_id).sum().item()
+    onnx_inputs = {key: [l.tolist() for l in inputs[key]] for key in inputs}
+
+    return onnx_inputs, total_tokens
+
+def convert_embeddings_to_openapi_format(embeddings, model_name, total_tokens):
+
+    formatted_data = []
+    openai_compat_response = {}
+    for idx, embedding in enumerate(embeddings):
+
+        formatted_embedding = {
+            "object": "embedding",
+            "embedding": embedding,
+            "index": idx
+        }
+        formatted_data.append(formatted_embedding)
+
+    # create the final OpenAICompatResponse format
+    openai_compat_response = {
+        "object": "list",
+        "data": formatted_data,
+        "model": model_name,  # Use the provided model name
+        "usage": {
+            "prompt_tokens": total_tokens,  # represents the token count for just the text input
+            "total_tokens": total_tokens     # total number of tokens involved in the request, same in case of embeddings
+        }
+    }
+
+    return openai_compat_response
+
+
+def post_inference(outputs, model_name, total_tokens):
+    """
+    Post-process the model results
+
+    Parameters
+    ----------
+    outputs: Data format after calling model.run
+    model_name: name of model
+    total_tokens: total tokens that were processed by the model
+
+    Returns
+    -------
+    outputs: Data format after any processing.
+
+    """
+    results = [embed.tolist() for embed in outputs]
+    response = convert_embeddings_to_openapi_format(results, model_name, total_tokens)
+    return response
+
+def predict(data, model=load_model()):
+    """
+    Returns prediction given the model and data to predict
+
+    Parameters
+    ----------
+    model: Model instance returned by load_model API.
+    data: Data format as expected by the predict API of the core estimator. For example, for scikit-learn models it could be a numpy array, list of lists, or pandas DataFrame.
+
+    Returns
+    -------
+    predictions: Output from scoring server
+        Format: {'prediction': output from model.predict method}
+
+    """
+    # inputs contains 'input_ids', 'token_type_ids', 'attention_mask' but 'token_type_ids' is optional
+    inputs, total_tokens = pre_inference(data)
+
+    onnx_inputs = [inp.name for inp in model.get_inputs()]
+    embeddings = model.run(None, {key: inputs[key] if key in inputs else None for key in onnx_inputs})[0]
+
+    response = post_inference(embeddings, data['model'], total_tokens)
+    return response

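A hypothetical smoke test for the rendered template, not part of the patch:
it assumes the artifact directory holds model.onnx, the tokenizer files, and
the schema file that the `openapi_schema` placeholder points at (the schema
is bundled by the follow-up patch below).

    import sys

    sys.path.insert(0, "/path/to/artifact_dir")  # hypothetical artifact location
    import score  # the rendered score_embedding_onnx template

    payload = {
        "input": ["What are activation functions?", "What is Deep Learning?"],
        "model": "sentence-transformers/all-MiniLM-L6-v2",
    }
    result = score.predict(payload)
    # one embedding per input string, plus token accounting
    print(len(result["data"]), result["usage"])
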
From 4681ead30bbc4074fa33e6e64ad86f66bf243866 Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Mon, 16 Dec 2024 13:29:28 -0500
Subject: [PATCH 2/7] Added support for embedding onnx model.

---
 ads/model/__init__.py                         |   24 +-
 ads/model/artifact.py                         |   55 +-
 .../extractor/embedding_onnx_extractor.py     |   35 +-
 ads/model/framework/embedding_onnx_model.py   |  331 +++-
 ads/model/generic_model.py                    |   50 +-
 ads/model/model_metadata.py                   |   15 +-
 ads/templates/schemas/openapi.json            | 1740 +++++++++++++++++
 ads/templates/score_embedding_onnx.jinja2     |   12 +-
 ...st_model_framework_embedding_onnx_model.py |  132 ++
 9 files changed, 2314 insertions(+), 80 deletions(-)
 create mode 100644 ads/templates/schemas/openapi.json
 create mode 100644 tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
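
This patch fills in the extractor, registers the `embedding_onnx` framework,
copies the bundled openapi.json into the artifact during `prepare()`, and
exports `EmbeddingONNXModel`. Condensed from the class docstring below, the
intended call sequence is (a sketch assuming the artifact directory already
contains the ONNX model and tokenizer files):

    from ads.model import EmbeddingONNXModel

    model = EmbeddingONNXModel(artifact_dir="/path/to/artifact_dir")
    model.prepare(
        inference_conda_env="onnxruntime_p311_gpu_x86_64",
        inference_python_version="3.11",
        model_file_name="model.onnx",
        force_overwrite=True,
    )

    data = {
        "input": ["What are activation functions?", "What is Deep Learning?"],
        "model": "sentence-transformers/all-MiniLM-L6-v2",
    }
    model.verify(data)  # local test; data must already be JSON serializable
    # verify()/predict() raise ValueError if auto_serialize_data=True, since
    # the overrides below refuse to auto serialize for this framework.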

diff --git a/ads/model/__init__.py b/ads/model/__init__.py
index 65895eda9..f0b0febae 100644
--- a/ads/model/__init__.py
+++ b/ads/model/__init__.py
@@ -1,29 +1,26 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from ads.model.generic_model import GenericModel, ModelState
 from ads.model.datascience_model import DataScienceModel
-from ads.model.model_properties import ModelProperties
+from ads.model.deployment.model_deployer import ModelDeployer
+from ads.model.deployment.model_deployment import ModelDeployment
+from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
 from ads.model.framework.automl_model import AutoMLModel
+from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
+from ads.model.framework.huggingface_model import HuggingFacePipelineModel
 from ads.model.framework.lightgbm_model import LightGBMModel
 from ads.model.framework.pytorch_model import PyTorchModel
 from ads.model.framework.sklearn_model import SklearnModel
+from ads.model.framework.spark_model import SparkPipelineModel
 from ads.model.framework.tensorflow_model import TensorFlowModel
 from ads.model.framework.xgboost_model import XGBoostModel
-from ads.model.framework.spark_model import SparkPipelineModel
-from ads.model.framework.huggingface_model import HuggingFacePipelineModel
-
-from ads.model.deployment.model_deployer import ModelDeployer
-from ads.model.deployment.model_deployment import ModelDeployment
-from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
-
+from ads.model.generic_model import GenericModel, ModelState
+from ads.model.model_properties import ModelProperties
+from ads.model.model_version_set import ModelVersionSet, experiment
 from ads.model.serde.common import SERDE
 from ads.model.serde.model_input import ModelInputSerializer
-
-from ads.model.model_version_set import ModelVersionSet, experiment
 from ads.model.service.oci_datascience_model_version_set import (
     ModelVersionSetNotExists,
     ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@
     "XGBoostModel",
     "SparkPipelineModel",
     "HuggingFacePipelineModel",
+    "EmbeddingONNXModel",
     "ModelDeployer",
     "ModelDeployment",
     "ModelDeploymentProperties",
diff --git a/ads/model/artifact.py b/ads/model/artifact.py
index 6659247a5..4c116eb78 100644
--- a/ads/model/artifact.py
+++ b/ads/model/artifact.py
@@ -1,28 +1,28 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import fnmatch
 import importlib
 import os
-import sys
 import shutil
+import sys
 import tempfile
 import uuid
-import fsspec
+from datetime import datetime
 from typing import Dict, Optional, Tuple
+
+import fsspec
+from jinja2 import Environment, PackageLoader
+
+from ads import __version__
 from ads.common import auth as authutil
 from ads.common import logger, utils
 from ads.common.object_storage_details import ObjectStorageDetails
 from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
 from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
 from ads.model.runtime.runtime_info import RuntimeInfo
-from jinja2 import Environment, PackageLoader
-import warnings
-from ads import __version__
-from datetime import datetime
 
 MODEL_ARTIFACT_VERSION = "3.0"
 REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ def prepare_score_py(
         ) as f:
             f.write(scorefn_template.render(context))
 
+    def prepare_schema(self, schema_name: str):
+        """Copies schema to artifact directory.
+
+        Parameters
+        ----------
+        schema_name: str
+            The schema name
+
+        Returns
+        -------
+        None
+
+        Raises
+        ------
+        FileNotFoundError
+            If `schema_name` doesn't exist.
+        """
+        uri_src = os.path.join(
+            os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
+            "templates",
+            "schemas",
+            f"{schema_name}",
+        )
+
+        if not os.path.exists(uri_src):
+            raise FileNotFoundError(
+                f"{schema_name} does not exist. "
+                "Ensure the schema name is valid or specify a different one."
+            )
+
+        uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))
+
+        utils.copy_file(
+            uri_src=uri_src,
+            uri_dst=uri_dst,
+            force_overwrite=True,
+            auth=self.auth,
+        )
+
     def reload(self):
         """Syncs the `score.py` to reload the model and predict function.
 
diff --git a/ads/model/extractor/embedding_onnx_extractor.py b/ads/model/extractor/embedding_onnx_extractor.py
index becfef0a7..9f3f6b463 100644
--- a/ads/model/extractor/embedding_onnx_extractor.py
+++ b/ads/model/extractor/embedding_onnx_extractor.py
@@ -3,11 +3,35 @@
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
+from ads.common.decorator.runtime_dependency import (
+    OptionalDependency,
+    runtime_dependency,
+)
 from ads.model.extractor.model_info_extractor import ModelInfoExtractor
+from ads.model.model_metadata import Framework
 
 
 class EmbeddingONNXExtractor(ModelInfoExtractor):
-    def __init__(self, model):
+    """Class that extract model metadata from EmbeddingONNXModel models.
+
+    Attributes
+    ----------
+    model: object
+        The model to extract metadata from.
+
+    Methods
+    -------
+    framework(self) -> str
+        Returns the framework of the model.
+    algorithm(self) -> object
+        Returns the algorithm of the model.
+    version(self) -> str
+        Returns the framework version of the model.
+    hyperparameter(self) -> dict
+        Returns the hyperparameters of the model.
+    """
+
+    def __init__(self, model=None):
         self.model = model
 
     @property
@@ -19,7 +43,7 @@ def framework(self):
         str:
            The framework of the model.
         """
-        pass
+        return Framework.EMBEDDING_ONNX
 
     @property
     def algorithm(self):
@@ -30,9 +54,10 @@ def algorithm(self):
         object:
            The algorithm of the model.
         """
-        pass
+        return "Embedding_ONNX"
 
     @property
+    @runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
     def version(self):
         """Extracts the framework version of the model.
 
@@ -41,7 +66,7 @@ def version(self):
         str:
            The framework version of the model.
         """
-        pass
+        return onnxruntime.__version__
 
     @property
     def hyperparameter(self):
@@ -52,4 +77,4 @@ def hyperparameter(self):
         dict:
            The hyperparameters of the model.
         """
-        pass
+        return None
diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
index 440733f51..51dace510 100644
--- a/ads/model/framework/embedding_onnx_model.py
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -3,39 +3,338 @@
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import Any, Callable, Dict
+from typing import Dict
 
 from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
 from ads.model.generic_model import FrameworkSpecificModel
-from ads.model.model_properties import ModelProperties
-from ads.model.serde.common import SERDE
 
 
 class EmbeddingONNXModel(FrameworkSpecificModel):
+    """EmbeddingONNXModel class for embedding onnx model.
+
+    Attributes
+    ----------
+    algorithm: str
+        The algorithm of the model.
+    artifact_dir: str
+        Artifact directory to store the files needed for deployment.
+    auth: Dict
+        Default authentication is set using the `ads.set_auth` API. To override the
+        default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
+        an authentication signer to instantiate an IdentityClient object.
+    framework: str
+        "embedding_onnx", the framework name of the model.
+    hyperparameter: dict
+        The hyperparameters of the estimator.
+    metadata_custom: ModelCustomMetadata
+        The model custom metadata.
+    metadata_provenance: ModelProvenanceMetadata
+        The model provenance metadata.
+    metadata_taxonomy: ModelTaxonomyMetadata
+        The model taxonomy metadata.
+    model_artifact: ModelArtifact
+        This is built by calling prepare.
+    model_deployment: ModelDeployment
+        A ModelDeployment instance.
+    model_file_name: str
+        Name of the serialized model.
+    model_id: str
+        The model ID.
+    properties: ModelProperties
+        ModelProperties object required to save and deploy model.
+    runtime_info: RuntimeInfo
+        A RuntimeInfo instance.
+    schema_input: Schema
+        Schema describes the structure of the input data.
+    schema_output: Schema
+        Schema describes the structure of the output data.
+    serialize: bool
+        Whether to serialize the model to a pkl file by default. If False, you need to serialize the model yourself,
+        save it under artifact_dir, and update the score.py manually.
+    version: str
+        The framework version of the model.
+
+    Methods
+    -------
+    delete_deployment(...)
+        Deletes the current model deployment.
+    deploy(..., **kwargs)
+        Deploys a model.
+    from_model_artifact(uri, ..., **kwargs)
+        Loads model from the specified folder, or zip/tar archive.
+    from_model_catalog(model_id, ..., **kwargs)
+        Loads model from model catalog.
+    from_model_deployment(model_deployment_id, ..., **kwargs)
+        Loads model from model deployment.
+    update_deployment(model_deployment_id, ..., **kwargs)
+        Updates a model deployment.
+    from_id(ocid, ..., **kwargs)
+        Loads model from model OCID or model deployment OCID.
+    introspect(...)
+        Runs model introspection.
+    predict(data, ...)
+        Returns prediction of input data run against the model deployment endpoint.
+    prepare(..., **kwargs)
+        Prepare and save the score.py, serialized model and runtime.yaml file.
+    prepare_save_deploy(..., **kwargs)
+        Shortcut for prepare, save and deploy steps.
+    reload(...)
+        Reloads the model artifact files: `score.py` and the `runtime.yaml`.
+    restart_deployment(...)
+        Restarts the model deployment.
+    save(..., **kwargs)
+        Saves model artifacts to the model catalog.
+    set_model_input_serializer(serde)
+        Registers serializer used for serializing data passed in verify/predict.
+    summary_status(...)
+        Gets a summary table of the current status.
+    verify(data, ...)
+        Tests if deployment works in local environment.
+    upload_artifact(...)
+        Uploads model artifacts to the provided `uri`.
+    download_artifact(...)
+        Downloads model artifacts from the model catalog.
+    update_summary_status(...)
+        Update the status in the summary table.
+    update_summary_action(...)
+        Update the actions needed from the user in the summary table.
+
+    Examples
+    --------
+    >>> import tempfile
+    >>> import os
+    >>> import shutil
+    >>> from ads.model import EmbeddingONNXModel
+    >>> from huggingface_hub import snapshot_download
+
+    >>> local_dir = tempfile.mkdtemp()
+    >>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
+    >>> snapshot_download(
+    ...     repo_id="sentence-transformers/all-MiniLM-L6-v2",
+    ...     local_dir=local_dir
+    ... )
+
+    >>> # copy all files from local_dir to artifact_dir
+    >>> artifact_dir = tempfile.mkdtemp()
+    >>> for root, dirs, files in os.walk(local_dir):
+    ...     for file in files:
+    ...         src_path = os.path.join(root, file)
+    ...         shutil.copy(src_path, artifact_dir)
+
+    >>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
+    >>> model.summary_status()
+    >>> model.prepare(
+    ...     inference_conda_env="onnxruntime_p311_gpu_x86_64",
+    ...     inference_python_version="3.11",
+    ...     model_file_name="model.onnx",
+    ...     force_overwrite=True
+    ... )
+    >>> model.summary_status()
+    >>> model.verify(
+    ...     {
+    ...         "input": ['What are activation functions?', 'What is Deep Learning?'],
+    ...         "model": "sentence-transformers/all-MiniLM-L6-v2"
+    ...     },
+    ... )
+    >>> model.summary_status()
+    >>> model.save(display_name="sentence-transformers/all-MiniLM-L6-v2")
+    >>> model.summary_status()
+    >>> model.deploy(
+    ...    display_name="all-MiniLM-L6-v2 Embedding deployment",
+    ...    deployment_instance_shape="VM.Standard.E4.Flex",
+    ...    deployment_ocpus=20,
+    ...    deployment_memory_in_gbs=256,
+    ... )
+    >>> model.predict(
+    ...     {
+    ...         "input": ['What are activation functions?', 'What is Deep Learning?'],
+    ...         "model": "sentence-transformers/all-MiniLM-L6-v2"
+    ...     },
+    ... )
+    >>> # Uncomment the line below to delete the model and the associated model deployment
+    >>> # model.delete(delete_associated_model_deployment = True)
+    """
+
     def __init__(
         self,
-        estimator: Callable[..., Any] = None,
         artifact_dir: str | None = None,
-        properties: ModelProperties | None = None,
         auth: Dict | None = None,
-        serialize: bool = True,
-        model_save_serializer: SERDE | None = None,
-        model_input_serializer: SERDE | None = None,
+        serialize: bool = False,
         **kwargs: dict,
     ) -> None:
+        """
+        Initiates an EmbeddingONNXModel instance.
+
+        Parameters
+        ----------
+        artifact_dir: str
+            Directory for the generated artifacts.
+        auth: (Dict, optional). Defaults to None.
+            The default authentication is set using the `ads.set_auth` API. If you need to override the
+            default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
+            authentication signer and kwargs required to instantiate IdentityClient object.
+        serialize: bool
+            Whether to serialize the model to a pkl file by default.
+            Must be `False` for the embedding ONNX model.
+
+        Returns
+        -------
+        EmbeddingONNXModel
+            EmbeddingONNXModel instance.
+
+        Examples
+        --------
+        >>> import tempfile
+        >>> import os
+        >>> import shutil
+        >>> from ads.model import EmbeddingONNXModel
+        >>> from huggingface_hub import snapshot_download
+
+        >>> local_dir = tempfile.mkdtemp()
+        >>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
+        >>> snapshot_download(
+        ...     repo_id="sentence-transformers/all-MiniLM-L6-v2",
+        ...     local_dir=local_dir
+        ... )
+
+        >>> # copy all files from local_dir to artifact_dir
+        >>> artifact_dir = tempfile.mkdtemp()
+        >>> for root, dirs, files in os.walk(local_dir):
+        ...     for file in files:
+        ...         src_path = os.path.join(root, file)
+        ...         shutil.copy(src_path, artifact_dir)
+
+        >>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
+        >>> model.summary_status()
+        >>> model.prepare(
+        ...     inference_conda_env="onnxruntime_p311_gpu_x86_64",
+        ...     inference_python_version="3.11",
+        ...     model_file_name="model.onnx",
+        ...     force_overwrite=True
+        ... )
+        >>> model.summary_status()
+        >>> model.verify(
+        ...     {
+        ...         "input": ['What are activation functions?', 'What is Deep Learning?'],
+        ...         "model": "sentence-transformers/all-MiniLM-L6-v2"
+        ...     },
+        ... )
+        >>> model.summary_status()
+        >>> model.save(display_name="sentence-transformers/all-MiniLM-L6-v2")
+        >>> model.summary_status()
+        >>> model.deploy(
+        ...    display_name="all-MiniLM-L6-v2 Embedding deployment",
+        ...    deployment_instance_shape="VM.Standard.E4.Flex",
+        ...    deployment_ocpus=20,
+        ...    deployment_memory_in_gbs=256,
+        ... )
+        >>> model.predict(
+        ...     {
+        ...         "input": ['What are activation functions?', 'What is Deep Learning?'],
+        ...         "model": "sentence-transformers/all-MiniLM-L6-v2"
+        ...     },
+        ... )
+        >>> # Uncomment the line below to delete the model and the associated model deployment
+        >>> # model.delete(delete_associated_model_deployment = True)
+        """
         super().__init__(
-            estimator,
-            artifact_dir,
-            properties,
-            auth,
-            serialize,
-            model_save_serializer,
-            model_input_serializer,
+            artifact_dir=artifact_dir,
+            auth=auth,
+            serialize=serialize,
             **kwargs,
         )
 
-        self._extractor = EmbeddingONNXExtractor(estimator)
+        self._extractor = EmbeddingONNXExtractor()
         self.framework = self._extractor.framework
         self.algorithm = self._extractor.algorithm
         self.version = self._extractor.version
         self.hyperparameter = self._extractor.hyperparameter
+
+    def verify(
+        self, data=None, reload_artifacts=True, auto_serialize_data=False, **kwargs
+    ):
+        """Test if embedding onnx model deployment works in local environment.
+
+        Examples
+        --------
+        >>> data = {
+        ...     "input": ['What are activation functions?', 'What is Deep Learning?'],
+        ...     "model": "sentence-transformers/all-MiniLM-L6-v2"
+        ... }
+        >>> prediction = model.verify(data)
+
+        Parameters
+        ----------
+        data: Any
+            Data used to test if deployment works in local environment.
+        reload_artifacts: bool. Defaults to True.
+            Whether to reload artifacts or not.
+        auto_serialize_data: bool.
+            Whether to auto serialize input data. Must be `False` for the embedding ONNX model.
+            Input `data` must be JSON serializable.
+        kwargs:
+            content_type: str, used to indicate the media type of the resource.
+            image: PIL.Image Object or uri for the image.
+               A valid string path for image file can be local path, http(s), oci, s3, gs.
+            storage_options: dict
+               Passed to `fsspec.open` for a particular storage connection.
+               Please see `fsspec` (https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.open) for more details.
+
+        Returns
+        -------
+        Dict
+            A dictionary which contains prediction results.
+        """
+        if auto_serialize_data:
+            raise ValueError(
+                "ADS will not auto serialize `data` for embedding onnx model. "
+                "Input json serializable `data` and set `auto_serialize_data` as False."
+            )
+
+        return super().verify(
+            data=data,
+            reload_artifacts=reload_artifacts,
+            auto_serialize_data=auto_serialize_data,
+            **kwargs,
+        )
+
+    def predict(self, data=None, auto_serialize_data=False, **kwargs):
+        """Returns prediction of input data run against the embedding onnx model deployment endpoint.
+
+        Examples
+        --------
+        >>> data = {
+        ...     "input": ['What are activation functions?', 'What is Deep Learning?'],
+        ...     "model": "sentence-transformers/all-MiniLM-L6-v2"
+        ... }
+        >>> prediction = model.predict(data)
+
+        Parameters
+        ----------
+        data: Any
+            Data for the prediction for model deployment.
+        auto_serialize_data: bool.
+            Whether to auto serialize input data. Must be `False` for the embedding ONNX model.
+            Input `data` must be JSON serializable.
+        kwargs:
+            content_type: str, used to indicate the media type of the resource.
+            image: PIL.Image Object or uri for the image.
+               A valid string path for image file can be local path, http(s), oci, s3, gs.
+            storage_options: dict
+               Passed to `fsspec.open` for a particular storage connection.
+               Please see `fsspec` (https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.open) for more details.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Dictionary with the predicted values.
+        """
+        if auto_serialize_data:
+            raise ValueError(
+                "ADS will not auto serialize `data` for embedding onnx model. "
+                "Input json serializable `data` and set `auto_serialize_data` as False."
+            )
+
+        return super().predict(
+            data=data, auto_serialize_data=auto_serialize_data, **kwargs
+        )
diff --git a/ads/model/generic_model.py b/ads/model/generic_model.py
index 842ae94a8..edead1252 100644
--- a/ads/model/generic_model.py
+++ b/ads/model/generic_model.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
 # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -8,9 +7,9 @@
 import os
 import shutil
 import tempfile
+import warnings
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import warnings
 
 import numpy as np
 import pandas as pd
@@ -21,8 +20,8 @@
 from ads.common import auth as authutil
 from ads.common import logger, utils
 from ads.common.decorator.utils import class_or_instance_method
-from ads.common.utils import DATA_SCHEMA_MAX_COL_NUM, get_files
 from ads.common.object_storage_details import ObjectStorageDetails
+from ads.common.utils import DATA_SCHEMA_MAX_COL_NUM, get_files
 from ads.config import (
     CONDA_BUCKET_NS,
     JOB_RUN_COMPARTMENT_OCID,
@@ -49,11 +48,11 @@
     DEFAULT_POLL_INTERVAL,
     DEFAULT_WAIT_TIME,
     ModelDeployment,
-    ModelDeploymentMode,
-    ModelDeploymentProperties,
     ModelDeploymentCondaRuntime,
-    ModelDeploymentInfrastructure,
     ModelDeploymentContainerRuntime,
+    ModelDeploymentInfrastructure,
+    ModelDeploymentMode,
+    ModelDeploymentProperties,
 )
 from ads.model.deployment.common.utils import State as ModelDeploymentState
 from ads.model.deployment.common.utils import send_request
@@ -66,10 +65,10 @@
 from ads.model.model_metadata import (
     ExtendedEnumMeta,
     Framework,
+    MetadataCustomCategory,
     ModelCustomMetadata,
     ModelProvenanceMetadata,
     ModelTaxonomyMetadata,
-    MetadataCustomCategory,
 )
 from ads.model.model_metadata_mixin import MetadataMixin
 from ads.model.model_properties import ModelProperties
@@ -940,11 +939,10 @@ def prepare(
                 manifest = fetch_manifest_from_conda_location(conda_prefix)
                 if "pack_path" in manifest:
                     self.properties.inference_conda_env = manifest["pack_path"]
-                else:
-                    if not self.ignore_conda_error:
-                        raise ValueError(
-                            "`inference_conda_env` must be specified for conda runtime. If you are using container runtime, set `ignore_conda_error=True`."
-                        )
+                elif not self.ignore_conda_error:
+                    raise ValueError(
+                        "`inference_conda_env` must be specified for conda runtime. If you are using container runtime, set `ignore_conda_error=True`."
+                    )
                 self.properties.inference_python_version = (
                     manifest["python"]
                     if "python" in manifest
@@ -1025,7 +1023,7 @@ def prepare(
                     detail=PREPARE_STATUS_SERIALIZE_MODEL_DETAIL,
                     status=ModelState.DONE.value,
                 )
-            except SerializeModelNotImplementedError as e:
+            except SerializeModelNotImplementedError:
                 if not utils.is_path_exists(
                     uri=os.path.join(self.artifact_dir, self.model_file_name),
                     auth=self.auth,
@@ -1056,17 +1054,19 @@ def prepare(
             except Exception as e:
                 raise e
 
+        if self.framework == Framework.EMBEDDING_ONNX:
+            self.model_artifact.prepare_schema(schema_name="openapi.json")
+
         if as_onnx:
             jinja_template_filename = "score_onnx_new"
+        elif self.framework and self.framework != "other":
+            jinja_template_filename = "score_" + self.framework
+            if self.framework == "transformers":
+                jinja_template_filename = "score_" + "huggingface_pipeline"
         else:
-            if self.framework and self.framework != "other":
-                jinja_template_filename = "score_" + self.framework
-                if self.framework == "transformers":
-                    jinja_template_filename = "score_" + "huggingface_pipeline"
-            else:
-                jinja_template_filename = (
-                    "score-pkl" if self._serialize else "score_generic"
-                )
+            jinja_template_filename = (
+                "score-pkl" if self._serialize else "score_generic"
+            )
 
         if score_py_uri:
             utils.copy_file(
@@ -1276,7 +1276,7 @@ def verify(
         if self.model_artifact is None:
             raise ArtifactsNotAvailableError
 
-        endpoint = f"http://127.0.0.1:8000/predict"
+        endpoint = "http://127.0.0.1:8000/predict"
         data = self._handle_input_data(data, auto_serialize_data, **kwargs)
 
         request_body = send_request(
@@ -2179,7 +2179,7 @@ def save(
                 )
                 self.update_summary_action(
                     detail=SAVE_STATUS_INTROSPECT_TEST_DETAIL,
-                    action=f"Use `.introspect()` method to get detailed information.",
+                    action="Use `.introspect()` method to get detailed information.",
                 )
                 raise IntrospectionNotPassed(msg)
             else:
@@ -2470,7 +2470,9 @@ def deploy(
             .with_shape_name(self.properties.deployment_instance_shape)
             .with_replica(self.properties.deployment_instance_count)
             .with_subnet_id(self.properties.deployment_instance_subnet_id)
-            .with_private_endpoint_id(self.properties.deployment_instance_private_endpoint_id)
+            .with_private_endpoint_id(
+                self.properties.deployment_instance_private_endpoint_id
+            )
         )
 
         web_concurrency = (
diff --git a/ads/model/model_metadata.py b/ads/model/model_metadata.py
index 2667b82ad..afe143190 100644
--- a/ads/model/model_metadata.py
+++ b/ads/model/model_metadata.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
 # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -11,20 +10,21 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
 from pathlib import Path
-from typing import Dict, List, Tuple, Union, Optional, Any
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import ads.dataset.factory as factory
 import fsspec
 import git
 import oci.data_science.models
 import pandas as pd
 import yaml
+from oci.util import to_dict
+
 from ads.common import logger
 from ads.common.error import ChangesNotCommitted
 from ads.common.extended_enum import ExtendedEnumMeta
-from ads.common.serializer import DataClassSerializable
 from ads.common.object_storage_details import ObjectStorageDetails
-from oci.util import to_dict
+from ads.common.serializer import DataClassSerializable
+from ads.dataset import factory
 
 try:
     from yaml import CDumper as dumper
@@ -173,6 +173,7 @@ class Framework(str, metaclass=ExtendedEnumMeta):
     WORD2VEC = "word2vec"
     ENSEMBLE = "ensemble"
     SPARK = "pyspark"
+    EMBEDDING_ONNX = "embedding_onnx"
     OTHER = "other"
 
 
@@ -1398,7 +1399,7 @@ def from_dict(cls, data: Dict) -> "ModelCustomMetadata":
         if (
             not data
             or not isinstance(data, Dict)
-            or not "data" in data
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
@@ -1550,7 +1551,7 @@ def from_dict(cls, data: Dict) -> "ModelTaxonomyMetadata":
         if (
             not data
             or not isinstance(data, Dict)
-            or not "data" in data
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
diff --git a/ads/templates/schemas/openapi.json b/ads/templates/schemas/openapi.json
new file mode 100644
index 000000000..672c937e7
--- /dev/null
+++ b/ads/templates/schemas/openapi.json
@@ -0,0 +1,1740 @@
+{
+  "components": {
+    "schemas": {
+      "ClassifierModel": {
+        "properties": {
+          "id2label": {
+            "additionalProperties": {
+              "type": "string"
+            },
+            "example": {
+              "0": "LABEL"
+            },
+            "type": "object"
+          },
+          "label2id": {
+            "additionalProperties": {
+              "minimum": 0,
+              "type": "integer"
+            },
+            "example": {
+              "LABEL": 0
+            },
+            "type": "object"
+          }
+        },
+        "required": [
+          "id2label",
+          "label2id"
+        ],
+        "type": "object"
+      },
+      "DecodeRequest": {
+        "properties": {
+          "ids": {
+            "$ref": "#/components/schemas/InputIds"
+          },
+          "skip_special_tokens": {
+            "default": "true",
+            "example": "true",
+            "type": "boolean"
+          }
+        },
+        "required": [
+          "ids"
+        ],
+        "type": "object"
+      },
+      "DecodeResponse": {
+        "example": [
+          "test"
+        ],
+        "items": {
+          "type": "string"
+        },
+        "type": "array"
+      },
+      "EmbedAllRequest": {
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "prompt_name": {
+            "default": "null",
+            "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `sentence-transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "EmbedAllResponse": {
+        "example": [
+          [
+            [
+              0.0,
+              1.0,
+              2.0
+            ]
+          ]
+        ],
+        "items": {
+          "items": {
+            "items": {
+              "format": "float",
+              "type": "number"
+            },
+            "type": "array"
+          },
+          "type": "array"
+        },
+        "type": "array"
+      },
+      "EmbedRequest": {
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "normalize": {
+            "default": "true",
+            "example": "true",
+            "type": "boolean"
+          },
+          "prompt_name": {
+            "default": "null",
+            "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `sentence-transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "EmbedResponse": {
+        "example": [
+          [
+            0.0,
+            1.0,
+            2.0
+          ]
+        ],
+        "items": {
+          "items": {
+            "format": "float",
+            "type": "number"
+          },
+          "type": "array"
+        },
+        "type": "array"
+      },
+      "EmbedSparseRequest": {
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "prompt_name": {
+            "default": "null",
+            "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `sentence-transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "EmbedSparseResponse": {
+        "items": {
+          "items": {
+            "$ref": "#/components/schemas/SparseValue"
+          },
+          "type": "array"
+        },
+        "type": "array"
+      },
+      "Embedding": {
+        "oneOf": [
+          {
+            "items": {
+              "format": "float",
+              "type": "number"
+            },
+            "type": "array"
+          },
+          {
+            "type": "string"
+          }
+        ]
+      },
+      "EmbeddingModel": {
+        "properties": {
+          "pooling": {
+            "example": "cls",
+            "type": "string"
+          }
+        },
+        "required": [
+          "pooling"
+        ],
+        "type": "object"
+      },
+      "EncodingFormat": {
+        "enum": [
+          "float",
+          "base64"
+        ],
+        "type": "string"
+      },
+      "ErrorResponse": {
+        "properties": {
+          "error": {
+            "type": "string"
+          },
+          "error_type": {
+            "$ref": "#/components/schemas/ErrorType"
+          }
+        },
+        "required": [
+          "error",
+          "error_type"
+        ],
+        "type": "object"
+      },
+      "ErrorType": {
+        "enum": [
+          "Unhealthy",
+          "Backend",
+          "Overloaded",
+          "Validation",
+          "Tokenizer"
+        ],
+        "type": "string"
+      },
+      "Info": {
+        "properties": {
+          "auto_truncate": {
+            "type": "boolean"
+          },
+          "docker_label": {
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "max_batch_requests": {
+            "default": "null",
+            "example": "null",
+            "minimum": 0,
+            "nullable": true,
+            "type": "integer"
+          },
+          "max_batch_tokens": {
+            "example": "2048",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "max_client_batch_size": {
+            "example": "32",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "max_concurrent_requests": {
+            "description": "Router Parameters",
+            "example": "128",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "max_input_length": {
+            "example": "512",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "model_dtype": {
+            "example": "float16",
+            "type": "string"
+          },
+          "model_id": {
+            "description": "Model info",
+            "example": "thenlper/gte-base",
+            "type": "string"
+          },
+          "model_sha": {
+            "example": "fca14538aa9956a46526bd1d0d11d69e19b5a101",
+            "nullable": true,
+            "type": "string"
+          },
+          "model_type": {
+            "$ref": "#/components/schemas/ModelType"
+          },
+          "sha": {
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "tokenization_workers": {
+            "example": "4",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "version": {
+            "description": "Router Info",
+            "example": "0.5.0",
+            "type": "string"
+          }
+        },
+        "required": [
+          "model_id",
+          "model_dtype",
+          "model_type",
+          "max_concurrent_requests",
+          "max_input_length",
+          "max_batch_tokens",
+          "max_client_batch_size",
+          "auto_truncate",
+          "tokenization_workers",
+          "version"
+        ],
+        "type": "object"
+      },
+      "Input": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/InputType"
+          },
+          {
+            "items": {
+              "$ref": "#/components/schemas/InputType"
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "InputIds": {
+        "oneOf": [
+          {
+            "items": {
+              "format": "int32",
+              "minimum": 0,
+              "type": "integer"
+            },
+            "type": "array"
+          },
+          {
+            "items": {
+              "items": {
+                "format": "int32",
+                "minimum": 0,
+                "type": "integer"
+              },
+              "type": "array"
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "InputType": {
+        "oneOf": [
+          {
+            "type": "string"
+          },
+          {
+            "items": {
+              "format": "int32",
+              "minimum": 0,
+              "type": "integer"
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "ModelType": {
+        "oneOf": [
+          {
+            "properties": {
+              "classifier": {
+                "$ref": "#/components/schemas/ClassifierModel"
+              }
+            },
+            "required": [
+              "classifier"
+            ],
+            "type": "object"
+          },
+          {
+            "properties": {
+              "embedding": {
+                "$ref": "#/components/schemas/EmbeddingModel"
+              }
+            },
+            "required": [
+              "embedding"
+            ],
+            "type": "object"
+          },
+          {
+            "properties": {
+              "reranker": {
+                "$ref": "#/components/schemas/ClassifierModel"
+              }
+            },
+            "required": [
+              "reranker"
+            ],
+            "type": "object"
+          }
+        ]
+      },
+      "OpenAICompatEmbedding": {
+        "properties": {
+          "embedding": {
+            "$ref": "#/components/schemas/Embedding"
+          },
+          "index": {
+            "example": "0",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "object": {
+            "example": "embedding",
+            "type": "string"
+          }
+        },
+        "required": [
+          "object",
+          "embedding",
+          "index"
+        ],
+        "type": "object"
+      },
+      "OpenAICompatErrorResponse": {
+        "properties": {
+          "code": {
+            "format": "int32",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "error_type": {
+            "$ref": "#/components/schemas/ErrorType"
+          },
+          "message": {
+            "type": "string"
+          }
+        },
+        "required": [
+          "message",
+          "code",
+          "error_type"
+        ],
+        "type": "object"
+      },
+      "OpenAICompatRequest": {
+        "properties": {
+          "encoding_format": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/EncodingFormat"
+              }
+            ],
+            "default": "float"
+          },
+          "input": {
+            "$ref": "#/components/schemas/Input"
+          },
+          "model": {
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "user": {
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          }
+        },
+        "required": [
+          "input"
+        ],
+        "type": "object"
+      },
+      "OpenAICompatResponse": {
+        "properties": {
+          "data": {
+            "items": {
+              "$ref": "#/components/schemas/OpenAICompatEmbedding"
+            },
+            "type": "array"
+          },
+          "model": {
+            "example": "thenlper/gte-base",
+            "type": "string"
+          },
+          "object": {
+            "example": "list",
+            "type": "string"
+          },
+          "usage": {
+            "$ref": "#/components/schemas/OpenAICompatUsage"
+          }
+        },
+        "required": [
+          "object",
+          "data",
+          "model",
+          "usage"
+        ],
+        "type": "object"
+      },
+      "OpenAICompatUsage": {
+        "properties": {
+          "prompt_tokens": {
+            "example": "512",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "total_tokens": {
+            "example": "512",
+            "minimum": 0,
+            "type": "integer"
+          }
+        },
+        "required": [
+          "prompt_tokens",
+          "total_tokens"
+        ],
+        "type": "object"
+      },
+      "PredictInput": {
+        "description": "Model input. Can be either a single string, a pair of strings or a batch of mixed single and pairs of strings.",
+        "example": "What is Deep Learning?",
+        "oneOf": [
+          {
+            "description": "A single string",
+            "type": "string"
+          },
+          {
+            "description": "A pair of strings",
+            "items": {
+              "type": "string"
+            },
+            "maxItems": 2,
+            "minItems": 2,
+            "type": "array"
+          },
+          {
+            "description": "A batch",
+            "items": {
+              "oneOf": [
+                {
+                  "description": "A single string",
+                  "items": {
+                    "type": "string"
+                  },
+                  "maxItems": 1,
+                  "minItems": 1,
+                  "type": "array"
+                },
+                {
+                  "description": "A pair of strings",
+                  "items": {
+                    "type": "string"
+                  },
+                  "maxItems": 2,
+                  "minItems": 2,
+                  "type": "array"
+                }
+              ]
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "PredictRequest": {
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/PredictInput"
+          },
+          "raw_scores": {
+            "default": "false",
+            "example": "false",
+            "type": "boolean"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "PredictResponse": {
+        "oneOf": [
+          {
+            "items": {
+              "$ref": "#/components/schemas/Prediction"
+            },
+            "type": "array"
+          },
+          {
+            "items": {
+              "items": {
+                "$ref": "#/components/schemas/Prediction"
+              },
+              "type": "array"
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "Prediction": {
+        "properties": {
+          "label": {
+            "example": "admiration",
+            "type": "string"
+          },
+          "score": {
+            "example": "0.5",
+            "format": "float",
+            "type": "number"
+          }
+        },
+        "required": [
+          "score",
+          "label"
+        ],
+        "type": "object"
+      },
+      "Rank": {
+        "properties": {
+          "index": {
+            "example": "0",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "score": {
+            "example": "1.0",
+            "format": "float",
+            "type": "number"
+          },
+          "text": {
+            "default": "null",
+            "example": "Deep Learning is ...",
+            "nullable": true,
+            "type": "string"
+          }
+        },
+        "required": [
+          "index",
+          "score"
+        ],
+        "type": "object"
+      },
+      "RerankRequest": {
+        "properties": {
+          "query": {
+            "example": "What is Deep Learning?",
+            "type": "string"
+          },
+          "raw_scores": {
+            "default": "false",
+            "example": "false",
+            "type": "boolean"
+          },
+          "return_text": {
+            "default": "false",
+            "example": "false",
+            "type": "boolean"
+          },
+          "texts": {
+            "example": [
+              "Deep Learning is ..."
+            ],
+            "items": {
+              "type": "string"
+            },
+            "type": "array"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "query",
+          "texts"
+        ],
+        "type": "object"
+      },
+      "RerankResponse": {
+        "items": {
+          "$ref": "#/components/schemas/Rank"
+        },
+        "type": "array"
+      },
+      "SimilarityInput": {
+        "properties": {
+          "sentences": {
+            "description": "A list of strings which will be compared against the source_sentence.",
+            "example": [
+              "What is Machine Learning?"
+            ],
+            "items": {
+              "type": "string"
+            },
+            "type": "array"
+          },
+          "source_sentence": {
+            "description": "The string that you wish to compare the other strings with. This can be a phrase, sentence,\nor longer passage, depending on the model being used.",
+            "example": "What is Deep Learning?",
+            "type": "string"
+          }
+        },
+        "required": [
+          "source_sentence",
+          "sentences"
+        ],
+        "type": "object"
+      },
+      "SimilarityParameters": {
+        "properties": {
+          "prompt_name": {
+            "default": "null",
+            "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `sentence-transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          },
+          "truncate": {
+            "default": "false",
+            "example": "false",
+            "nullable": true,
+            "type": "boolean"
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        },
+        "required": [
+          "truncation_direction"
+        ],
+        "type": "object"
+      },
+      "SimilarityRequest": {
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/SimilarityInput"
+          },
+          "parameters": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/SimilarityParameters"
+              }
+            ],
+            "default": "null",
+            "nullable": true
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "SimilarityResponse": {
+        "example": [
+          0.0,
+          1.0,
+          0.5
+        ],
+        "items": {
+          "format": "float",
+          "type": "number"
+        },
+        "type": "array"
+      },
+      "SimpleToken": {
+        "properties": {
+          "id": {
+            "example": 0,
+            "format": "int32",
+            "minimum": 0,
+            "type": "integer"
+          },
+          "special": {
+            "example": "false",
+            "type": "boolean"
+          },
+          "start": {
+            "example": 0,
+            "minimum": 0,
+            "nullable": true,
+            "type": "integer"
+          },
+          "stop": {
+            "example": 2,
+            "minimum": 0,
+            "nullable": true,
+            "type": "integer"
+          },
+          "text": {
+            "example": "test",
+            "type": "string"
+          }
+        },
+        "required": [
+          "id",
+          "text",
+          "special"
+        ],
+        "type": "object"
+      },
+      "SparseValue": {
+        "properties": {
+          "index": {
+            "minimum": 0,
+            "type": "integer"
+          },
+          "value": {
+            "format": "float",
+            "type": "number"
+          }
+        },
+        "required": [
+          "index",
+          "value"
+        ],
+        "type": "object"
+      },
+      "TokenizeInput": {
+        "oneOf": [
+          {
+            "type": "string"
+          },
+          {
+            "items": {
+              "type": "string"
+            },
+            "type": "array"
+          }
+        ]
+      },
+      "TokenizeRequest": {
+        "properties": {
+          "add_special_tokens": {
+            "default": "true",
+            "example": "true",
+            "type": "boolean"
+          },
+          "inputs": {
+            "$ref": "#/components/schemas/TokenizeInput"
+          },
+          "prompt_name": {
+            "default": "null",
+            "description": "The name of the prompt that should be used by for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `sentence-transformers` configuration `prompts` dictionary.\n\nFor example if ``prompt_name`` is \"query\" and the ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended before\nany text to encode.",
+            "example": "null",
+            "nullable": true,
+            "type": "string"
+          }
+        },
+        "required": [
+          "inputs"
+        ],
+        "type": "object"
+      },
+      "TokenizeResponse": {
+        "example": [
+          [
+            {
+              "id": 0,
+              "special": false,
+              "start": 0,
+              "stop": 2,
+              "text": "test"
+            }
+          ]
+        ],
+        "items": {
+          "items": {
+            "$ref": "#/components/schemas/SimpleToken"
+          },
+          "type": "array"
+        },
+        "type": "array"
+      },
+      "TruncationDirection": {
+        "enum": [
+          "Left",
+          "Right"
+        ],
+        "type": "string"
+      }
+    }
+  },
+  "info": {
+    "contact": {
+      "name": "Olivier Dehaene"
+    },
+    "description": "Text Embedding Webserver",
+    "license": {
+      "name": "Apache 2.0",
+      "url": "https://www.apache.org/licenses/LICENSE-2.0"
+    },
+    "title": "Text Embeddings Inference",
+    "version": "1.6.0"
+  },
+  "openapi": "3.0.3",
+  "paths": {
+    "/decode": {
+      "post": {
+        "operationId": "decode",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/DecodeRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/DecodeResponse"
+                }
+              }
+            },
+            "description": "Decoded ids"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          }
+        },
+        "summary": "Decode input ids",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/embed": {
+      "post": {
+        "operationId": "embed",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/EmbedRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/EmbedResponse"
+                }
+              }
+            },
+            "description": "Embeddings"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Embedding Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/embed_all": {
+      "post": {
+        "description": "Returns a 424 status code if the model is not an embedding model.",
+        "operationId": "embed_all",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/EmbedAllRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/EmbedAllResponse"
+                }
+              }
+            },
+            "description": "Embeddings"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Embedding Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get all Embeddings without Pooling.",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/embed_sparse": {
+      "post": {
+        "operationId": "embed_sparse",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/EmbedSparseRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/EmbedSparseResponse"
+                }
+              }
+            },
+            "description": "Embeddings"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Embedding Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get Sparse Embeddings. Returns a 424 status code if the model is not an embedding model with SPLADE pooling.",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/health": {
+      "get": {
+        "operationId": "health",
+        "responses": {
+          "200": {
+            "description": "Everything is working fine"
+          },
+          "503": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "unhealthy",
+                  "error_type": "unhealthy"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Text embeddings Inference is down"
+          }
+        },
+        "summary": "Health check method",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/info": {
+      "get": {
+        "operationId": "get_model_info",
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/Info"
+                }
+              }
+            },
+            "description": "Served model info"
+          }
+        },
+        "summary": "Text Embeddings Inference endpoint info",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/metrics": {
+      "get": {
+        "operationId": "metrics",
+        "responses": {
+          "200": {
+            "content": {
+              "text/plain": {
+                "schema": {
+                  "type": "string"
+                }
+              }
+            },
+            "description": "Prometheus Metrics"
+          }
+        },
+        "summary": "Prometheus metrics scrape endpoint",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/predict": {
+      "post": {
+        "operationId": "predict",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/PredictRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/PredictResponse"
+                }
+              }
+            },
+            "description": "Predictions"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Prediction Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/rerank": {
+      "post": {
+        "description": "a single class.",
+        "operationId": "rerank",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/RerankRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/RerankResponse"
+                }
+              }
+            },
+            "description": "Ranks"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Rerank Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/similarity": {
+      "post": {
+        "operationId": "similarity",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/SimilarityRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/SimilarityResponse"
+                }
+              }
+            },
+            "description": "Sentence Similarity"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Embedding Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/tokenize": {
+      "post": {
+        "operationId": "tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/TokenizeRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/TokenizeResponse"
+                }
+              }
+            },
+            "description": "Tokenized ids"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          }
+        },
+        "summary": "Tokenize inputs",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    },
+    "/v1/embeddings": {
+      "post": {
+        "operationId": "openai_embed",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/OpenAICompatRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatResponse"
+                }
+              }
+            },
+            "description": "Embeddings"
+          },
+          "413": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Batch size error",
+                  "type": "validation"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                }
+              }
+            },
+            "description": "Batch size error"
+          },
+          "422": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                }
+              }
+            },
+            "description": "Tokenization error"
+          },
+          "424": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Inference failed",
+                  "type": "backend"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                }
+              }
+            },
+            "description": "Embedding Error"
+          },
+          "429": {
+            "content": {
+              "application/json": {
+                "example": {
+                  "message": "Model is overloaded",
+                  "type": "overloaded"
+                },
+                "schema": {
+                  "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                }
+              }
+            },
+            "description": "Model is overloaded"
+          }
+        },
+        "summary": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.",
+        "tags": [
+          "Text Embeddings Inference"
+        ]
+      }
+    }
+  },
+  "tags": [
+    {
+      "description": "Hugging Face Text Embeddings Inference API",
+      "name": "Text Embeddings Inference"
+    }
+  ]
+}
diff --git a/ads/templates/score_embedding_onnx.jinja2 b/ads/templates/score_embedding_onnx.jinja2
index 8a830f073..9d01a9cd7 100644
--- a/ads/templates/score_embedding_onnx.jinja2
+++ b/ads/templates/score_embedding_onnx.jinja2
@@ -1,4 +1,4 @@
-# score.py 1.0 generated by ADS 2.11.10 on 20241002_212041
+# score.py {{SCORE_VERSION}} generated by ADS {{ADS_VERSION}} on {{time_created}}
 import os
 import sys
 import json
@@ -9,8 +9,8 @@ from jsonschema import validate, ValidationError
 from transformers import AutoTokenizer
 import logging
 
-model_name = 'model.onnx'
-openapi_schema = ''
+model_name = '{{model_file_name}}'
+openapi_schema = 'openapi.json'
 
 
 """
@@ -32,9 +32,9 @@ def load_model(model_file_name=model_name):
         sys.path.insert(0, model_dir)
     contents = os.listdir(model_dir)
     if model_file_name in contents:
-        # print(f'Start loading {model_file_name} from model directory {model_dir} ...')
+        print(f'Start loading {model_file_name} from model directory {model_dir} ...')
         model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=['CUDAExecutionProvider','CPUExecutionProvider'])
-        # print("Model is successfully loaded.")
+        print("Model is successfully loaded.")
         return model
     else:
         raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
@@ -43,7 +43,6 @@ def load_model(model_file_name=model_name):
 @lru_cache(maxsize=1)
 def load_tokenizer(model_full_name):
 
-    # todo: do we need model_full_name or have configs in artifact dir?
     model_dir = os.path.dirname(os.path.realpath(__file__))
     # initialize tokenizer
     return AutoTokenizer.from_pretrained(model_dir, clean_up_tokenization_spaces=True)
@@ -83,7 +82,6 @@ def validate_inputs(data):
         # validate the input JSON
         validate(instance=data, schema=request_schema, resolver=resolver)
     except ValidationError as e:
-        # todo: add custom error code and message in error handler
         example_value = {
             "input": ["What are activation functions?"],
             "encoding_format": "float",
diff --git a/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py b/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
new file mode 100644
index 000000000..571cb298b
--- /dev/null
+++ b/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import os
+import shutil
+import tempfile
+from unittest.mock import patch
+
+import pytest
+import yaml
+from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
+from ads.model.model_metadata import Framework
+
+
+class TestEmbeddingONNXModel:
+    def setup_class(cls):
+        cls.tmp_model_dir = tempfile.mkdtemp()
+        os.makedirs(cls.tmp_model_dir, exist_ok=True)
+        cls.inference_conda = "oci://fake_bucket@fake_namespace/inference_conda"
+        cls.training_conda = "oci://fake_bucket@fake_namespace/training_conda"
+
+    def test_init(self):
+        model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
+        assert model.algorithm == "Embedding_ONNX"
+        assert model.framework == Framework.EMBEDDING_ONNX
+
+    @patch("ads.model.generic_model.GenericModel.verify")
+    def test_prepare_and_verify(self, mock_verify):
+        mock_verify.return_value = {"results": "successful"}
+
+        model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
+        model.prepare(
+            model_file_name="test_model_file_name",
+            inference_conda_env=self.inference_conda,
+            inference_python_version="3.8",
+            training_conda_env=self.training_conda,
+            training_python_version="3.8",
+            force_overwrite=True,
+        )
+
+        assert model.model_file_name == "test_model_file_name"
+        artifacts = os.listdir(model.artifact_dir)
+        assert "score.py" in artifacts
+        assert "runtime.yaml" in artifacts
+        assert "openapi.json" in artifacts
+
+        runtime_yaml = os.path.join(model.artifact_dir, "runtime.yaml")
+        with open(runtime_yaml, "r") as f:
+            runtime_dict = yaml.safe_load(f)
+
+            assert (
+                runtime_dict["MODEL_DEPLOYMENT"]["INFERENCE_CONDA_ENV"][
+                    "INFERENCE_ENV_PATH"
+                ]
+                == self.inference_conda
+            )
+            assert (
+                runtime_dict["MODEL_DEPLOYMENT"]["INFERENCE_CONDA_ENV"][
+                    "INFERENCE_PYTHON_VERSION"
+                ]
+                == "3.8"
+            )
+            assert (
+                runtime_dict["MODEL_PROVENANCE"]["TRAINING_CONDA_ENV"][
+                    "TRAINING_ENV_PATH"
+                ]
+                == self.training_conda
+            )
+            assert (
+                runtime_dict["MODEL_PROVENANCE"]["TRAINING_CONDA_ENV"][
+                    "TRAINING_PYTHON_VERSION"
+                ]
+                == "3.8"
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="ADS will not auto serialize `data` for embedding onnx model. Input json serializable `data` and set `auto_serialize_data` as False.",
+        ):
+            model.verify(data="test_data", auto_serialize_data=True)
+
+        model.verify(data="test_data")
+        mock_verify.assert_called_with(
+            data="test_data",
+            reload_artifacts=True,
+            auto_serialize_data=False,
+        )
+
+    @patch("ads.model.generic_model.GenericModel.predict")
+    @patch("ads.model.generic_model.GenericModel.deploy")
+    @patch("ads.model.generic_model.GenericModel.save")
+    def test_prepare_save_deploy_predict(self, mock_save, mock_deploy, mock_predict):
+        model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
+        model.prepare(
+            model_file_name="test_model_file_name",
+            inference_conda_env=self.inference_conda,
+            inference_python_version="3.8",
+            training_conda_env=self.training_conda,
+            training_python_version="3.8",
+            force_overwrite=True,
+        )
+        model.save(display_name="test_embedding_onne_model")
+        model.deploy(
+            display_name="test_embedding_onne_model_deployment",
+            deployment_instance_shape="VM.Standard.E4.Flex",
+            deployment_ocpus=20,
+            deployment_memory_in_gbs=256,
+        )
+
+        with pytest.raises(
+            ValueError,
+            match="ADS will not auto serialize `data` for embedding onnx model. Input json serializable `data` and set `auto_serialize_data` as False.",
+        ):
+            model.verify(data="test_data", auto_serialize_data=True)
+
+        model.predict(data="test_data")
+        mock_predict.assert_called_with(
+            data="test_data",
+            auto_serialize_data=False,
+        )
+        mock_save.assert_called_with(display_name="test_embedding_onne_model")
+        mock_deploy.assert_called_with(
+            display_name="test_embedding_onne_model_deployment",
+            deployment_instance_shape="VM.Standard.E4.Flex",
+            deployment_ocpus=20,
+            deployment_memory_in_gbs=256,
+        )
+
+    def teardown_class(cls):
+        shutil.rmtree(cls.tmp_model_dir, ignore_errors=True)

From 198eac5d2d4c9ca7b583914d094ef771f0869d0c Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Mon, 16 Dec 2024 15:14:41 -0500
Subject: [PATCH 3/7] Updated docs.

---
 .../_template/summary_status.rst              |   4 +-
 .../framework_specific_instruction.rst        |   2 +-
 .../frameworks/embeddingonnxmodel.rst         | 274 ++++++++++++++++++
 3 files changed, 277 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst

diff --git a/docs/source/user_guide/model_registration/_template/summary_status.rst b/docs/source/user_guide/model_registration/_template/summary_status.rst
index 3cc8759fc..c59dcce73 100644
--- a/docs/source/user_guide/model_registration/_template/summary_status.rst
+++ b/docs/source/user_guide/model_registration/_template/summary_status.rst
@@ -1,3 +1,3 @@
-You can call the ``.summary_status()`` method after a model serialization instance such as ``GenericModel``, ``SklearnModel``, ``TensorFlowModel``, or ``PyTorchModel`` is created. The ``.summary_status()`` method returns a Pandas dataframe that guides you through the entire workflow. It shows which methods are available to call and which ones aren't. Plus it outlines what each method does. If extra actions are required, it also shows those actions.
+You can call the ``.summary_status()`` method after a model serialization instance such as ``GenericModel``, ``SklearnModel``, ``TensorFlowModel``, ``EmbeddingONNXModel``, or ``PyTorchModel`` is created. The ``.summary_status()`` method returns a Pandas dataframe that guides you through the entire workflow. It shows which methods are available to call and which ones aren't. Plus it outlines what each method does. If extra actions are required, it also shows those actions.
 
-The following image displays an example summary status table created after a user initiates a model instance. The table's Step column displays a Status of Done for the initiate step. And the ``Details`` column explains what the initiate step did such as generating a ``score.py`` file. The Step column also displays  the ``prepare()``, ``verify()``, ``save()``, ``deploy()``, and ``predict()`` methods for the model. The Status column displays which method is available next. After the initiate step,  the ``prepare()`` method is available. The next step is to call the ``prepare()`` method.
\ No newline at end of file
+The following image displays an example summary status table created after a user initiates a model instance. The table's Step column displays a Status of Done for the initiate step. And the ``Details`` column explains what the initiate step did such as generating a ``score.py`` file. The Step column also displays  the ``prepare()``, ``verify()``, ``save()``, ``deploy()``, and ``predict()`` methods for the model. The Status column displays which method is available next. After the initiate step,  the ``prepare()`` method is available. The next step is to call the ``prepare()`` method.
diff --git a/docs/source/user_guide/model_registration/framework_specific_instruction.rst b/docs/source/user_guide/model_registration/framework_specific_instruction.rst
index 9150bf2fb..0a5d748a7 100644
--- a/docs/source/user_guide/model_registration/framework_specific_instruction.rst
+++ b/docs/source/user_guide/model_registration/framework_specific_instruction.rst
@@ -10,6 +10,6 @@
     frameworks/lightgbmmodel
     frameworks/xgboostmodel
     frameworks/huggingfacemodel
+    frameworks/embeddingonnxmodel
     frameworks/automlmodel
     frameworks/genericmodel
-
diff --git a/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst b/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst
new file mode 100644
index 000000000..9fd22f6b1
--- /dev/null
+++ b/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst
@@ -0,0 +1,274 @@
+EmbeddingONNXModel
+******************
+
+See `API Documentation <../../../ads.model.framework.html#ads.model.framework.embedding_onnx_model.EmbeddingONNXModel>`__
+
+Overview
+========
+
+The ``ads.model.framework.embedding_onnx_model.EmbeddingONNXModel`` class in ADS is designed to rapidly get an Embedding ONNX Model into production. The ``.prepare()`` method creates the required model artifacts without any configuration or custom code. However, you can customize the generated ``score.py`` file.
+
+.. include:: ../_template/overview.rst
+
+The following steps take the `sentence-transformers/all-MiniLM-L6-v2 <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>`_ model and deploy it into production with a few lines of code.
+
+
+**Download Embedding Model from HuggingFace**
+
+.. code-block:: python3
+
+    import tempfile
+    import os
+    import shutil
+    from huggingface_hub import snapshot_download
+
+    local_dir = tempfile.mkdtemp()
+
+    # download the files needed for this demonstration to a local folder
+    snapshot_download(
+        repo_id="sentence-transformers/all-MiniLM-L6-v2",
+        local_dir=local_dir,
+        allow_patterns=[
+            "onnx/model.onnx",
+            "config.json",
+            "special_tokens_map.json",
+            "tokenizer_config.json",
+            "tokenizer.json",
+            "vocab.txt"
+        ]
+    )
+
+    artifact_dir = tempfile.mkdtemp()
+    # copy all downloaded files to artifact folder
+    for root, dirs, files in os.walk(local_dir):
+        for file in files:
+            src_path = os.path.join(root, file)
+            shutil.copy(src_path, artifact_dir)
+
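+Note that the copy loop above flattens the download, so ``onnx/model.onnx`` ends up at the top level of ``artifact_dir``. The generated ``score.py`` loads the model file directly from the artifact directory, so this layout is required.
+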
+
+Install Conda Pack
+==================
+
+To deploy the embedding ONNX model, start with the ONNX conda pack with the slug ``onnxruntime_p311_gpu_x86_64``.
+
+.. code-block:: bash
+
+    odsc conda install -s onnxruntime_p311_gpu_x86_64
+
+
+Prepare Model Artifact
+======================
+
+Instantiate an ``EmbeddingONNXModel()`` object with the embedding ONNX model. All model-related files are saved under ``artifact_dir``. ADS auto-generates the ``score.py``, ``runtime.yaml``, and ``openapi.json`` files that are required for deployment.
+
+For more detailed information on the parameters that ``EmbeddingONNXModel`` takes, refer to the `API Documentation <../../../ads.model.framework.html#ads.model.framework.embedding_onnx_model.EmbeddingONNXModel>`__.
+
+
+.. code-block:: python3
+
+    import ads
+    from ads.model import EmbeddingONNXModel
+
+    # other options are `api_keys` or `security_token` depending on where the code is executed
+    ads.set_auth("resource_principal")
+
+    embedding_onnx_model = EmbeddingONNXModel(artifact_dir=artifact_dir)
+    embedding_onnx_model.prepare(
+        inference_conda_env="onnxruntime_p311_gpu_x86_64",
+        inference_python_version="3.11",
+        model_file_name="model.onnx",
+        force_overwrite=True
+    )
+
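+After ``prepare()`` completes, the generated files can be inspected locally. A quick check, assuming the ``prepare()`` call above succeeded:
+
+.. code-block:: python3
+
+    import os
+
+    # the artifact directory should now contain score.py, runtime.yaml,
+    # openapi.json, and the copied model files
+    print(os.listdir(artifact_dir))
+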
+
+Summary Status
+==============
+
+.. include:: ../_template/summary_status.rst
+
+.. figure:: ../figures/summary_status.png
+   :align: center
+
+
+Verify Model
+============
+
+Call the ``verify()`` method to check whether the model can be executed locally.
+
+.. code-block:: python3
+
+    embedding_onnx_model.verify(
+        {
+            "input": ['What are activation functions?', 'What is Deep Learning?'],
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
+        },
+    )
+
+If successful, results similar to the following are returned.
+
+.. code-block:: python3
+
+    {
+        'object': 'list',
+        'data': 
+            [{
+                'object': 'embedding',
+                'embedding': 
+                    [[
+                        -0.11011122167110443,
+                        -0.39235609769821167,
+                        0.38759472966194153,
+                        -0.34653618931770325,
+                        ...,
+                    ]]
+            }]
+    }
+
+Register Model
+==============
+
+Save the model artifacts and create a model entry in the OCI Data Science Model Catalog.
+
+.. code-block:: python3
+
+    embedding_onnx_model.save(display_name="sentence-transformers/all-MiniLM-L6-v2")
+
+
+Deploy and Generate Endpoint
+============================
+
+Create a model deployment from the embedding ONNX model registered in the Model Catalog. The process takes several minutes, and the deployment configuration is displayed once it completes.
+
+.. code-block:: python3
+
+    embedding_onnx_model.deploy(
+        display_name="all-MiniLM-L6-v2 Embedding Model Deployment",
+        deployment_log_group_id="<log_group_id>",
+        deployment_access_log_id="<access_log_id>",
+        deployment_predict_log_id="<predict_log_id>",
+        deployment_instance_shape="VM.Standard.E4.Flex",
+        deployment_ocpus=20,
+        deployment_memory_in_gbs=256,
+    )
+
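+Once the deployment is active, the endpoint URI can be read from the deployment object. A minimal sketch, assuming the ``deploy()`` call above succeeded and the resulting deployment is exposed through the ``model_deployment`` attribute:
+
+.. code-block:: python3
+
+    # a sketch: retrieve the endpoint of the deployment created above
+    # (assumes deploy() completed and populated `model_deployment`)
+    endpoint = embedding_onnx_model.model_deployment.url
+
+    # predictions are served on the /predict route of the deployment
+    print(f"{endpoint}/predict")
+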
+
+Run Prediction against Endpoint
+===============================
+
+Call ``predict()`` to invoke the model deployment endpoint.
+
+.. code-block:: python3
+
+    embedding_onnx_model.predict(
+        {
+            "input": ["What are activation functions?", "What is Deep Learning?"],
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
+        },
+    )
+
+If successful, results similar to the following are returned.
+
+.. code-block:: python3
+
+    {
+        'object': 'list',
+        'data': 
+            [{
+                'object': 'embedding',
+                'embedding': 
+                    [[
+                        -0.11011122167110443,
+                        -0.39235609769821167,
+                        0.38759472966194153,
+                        -0.34653618931770325,
+                        ...,
+                    ]]
+            }]
+    }
+
+Run Prediction with OCI CLI
+===========================
+
+Model deployment endpoints can also be invoked with the OCI CLI.
+
+.. code-block:: bash
+
+    oci raw-request --http-method POST --target-uri <deployment_endpoint> --request-body '{"input": ["What are activation functions?", "What is Deep Learning?"], "model": "sentence-transformers/all-MiniLM-L6-v2"}' --auth resource_principal
+
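+The same request can also be sent from Python by signing it with the OCI SDK. A minimal sketch, assuming the ``oci`` and ``requests`` packages are installed, resource principal auth is available where the code runs, and ``<deployment_endpoint>`` is replaced with the actual endpoint:
+
+.. code-block:: python3
+
+    import oci
+    import requests
+
+    # a sketch: sign the request with a resource principal signer
+    # (assumes resource principal auth is available in this environment)
+    signer = oci.auth.signers.get_resource_principals_signer()
+
+    response = requests.post(
+        "<deployment_endpoint>",
+        json={
+            "input": ["What are activation functions?", "What is Deep Learning?"],
+            "model": "sentence-transformers/all-MiniLM-L6-v2",
+        },
+        auth=signer,
+    )
+    print(response.json())
+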
+
+Example
+=======
+
+.. code-block:: python3
+
+    import tempfile
+    import os
+    import shutil
+    import ads
+    from ads.model import EmbeddingONNXModel
+    from huggingface_hub import snapshot_download
+
+    # other options are `api_keys` or `security_token` depending on where the code is executed
+    ads.set_auth("resource_principal")
+
+    local_dir = tempfile.mkdtemp()
+
+    # download the files needed for this demonstration to a local folder
+    snapshot_download(
+        repo_id="sentence-transformers/all-MiniLM-L6-v2",
+        local_dir=local_dir,
+        allow_patterns=[
+            "onnx/model.onnx",
+            "config.json",
+            "special_tokens_map.json",
+            "tokenizer_config.json",
+            "tokenizer.json",
+            "vocab.txt"
+        ]
+    )
+
+    artifact_dir = tempfile.mkdtemp()
+    # copy all downloaded files to artifact folder
+    for root, dirs, files in os.walk(local_dir):
+        for file in files:
+            src_path = os.path.join(root, file)
+            shutil.copy(src_path, artifact_dir)
+
+    # initialize EmbeddingONNXModel instance and prepare score.py, runtime.yaml and openapi.json files.
+    embedding_onnx_model = EmbeddingONNXModel(artifact_dir=artifact_dir)
+    embedding_onnx_model.prepare(
+        inference_conda_env="onnxruntime_p311_gpu_x86_64",
+        inference_python_version="3.11",
+        model_file_name="model.onnx",
+        force_overwrite=True
+    )
+
+    # validate the model locally
+    embedding_onnx_model.verify(
+        {
+            "input": ['What are activation functions?', 'What is Deep Learning?'],
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
+        },
+    )
+
+    # save model to oci model catalog
+    embedding_onnx_model.save(display_name="sentence-transformers/all-MiniLM-L6-v2")
+
+    # deploy model
+    embedding_onnx_model.deploy(
+        display_name="all-MiniLM-L6-v2 Embedding Model Deployment",
+        deployment_log_group_id="<log_group_id>",
+        deployment_access_log_id="<access_log_id>",
+        deployment_predict_log_id="<predict_log_id>",
+        deployment_instance_shape="VM.Standard.E4.Flex",
+        deployment_ocpus=20,
+        deployment_memory_in_gbs=256,
+    )
+
+    # check model deployment endpoint
+    embedding_onnx_model.predict(
+        {
+            "input": ["What are activation functions?", "What is Deep Learning?"],
+            "model": "sentence-transformers/all-MiniLM-L6-v2"
+        },
+    )

From 5fb44e7214fb63422eda9ec3137b28197233c1b3 Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Mon, 16 Dec 2024 15:56:30 -0500
Subject: [PATCH 4/7] Updated pr.

---
 ads/model/framework/embedding_onnx_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
index 51dace510..94be3ea3f 100644
--- a/ads/model/framework/embedding_onnx_model.py
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import Dict, Self
+from typing import Dict
 
 from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
 from ads.model.generic_model import FrameworkSpecificModel
@@ -161,7 +161,7 @@ def __init__(
         auth: Dict | None = None,
         serialize: bool = False,
         **kwargs: dict,
-    ) -> Self:
+    ):
         """
         Initiates a EmbeddingONNXModel instance.
 

From e84715c1369e8bc9e15062cf0aa6515f0f10685c Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Mon, 16 Dec 2024 23:04:40 -0500
Subject: [PATCH 5/7] Updated pr.

---
 ads/model/framework/embedding_onnx_model.py   | 50 ++++++++++++------
 ads/templates/score_embedding_onnx.jinja2     | 16 +++++-
 .../frameworks/embeddingonnxmodel.rst         | 52 +++++++++----------
 3 files changed, 74 insertions(+), 44 deletions(-)

diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
index 94be3ea3f..316b37599 100644
--- a/ads/model/framework/embedding_onnx_model.py
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import Dict
+from typing import Dict, Optional
 
 from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
 from ads.model.generic_model import FrameworkSpecificModel
@@ -108,18 +108,26 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
     >>> from huggingface_hub import snapshot_download
 
     >>> local_dir=tempfile.mkdtemp()
-    >>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
+    >>> allow_patterns=[
+    ...     "onnx/model.onnx",
+    ...     "config.json",
+    ...     "special_tokens_map.json",
+    ...     "tokenizer_config.json",
+    ...     "tokenizer.json",
+    ...     "vocab.txt"
+    ... ]
+
+    >>> # download the files needed for this demonstration to a local folder
     >>> snapshot_download(
     ...     repo_id="sentence-transformers/all-MiniLM-L6-v2",
-    ...     local_dir=local_dir
+    ...     local_dir=local_dir,
+    ...     allow_patterns=allow_patterns
     ... )
 
-    >>> # copy all files from local_dir to artifact_dir
     >>> artifact_dir = tempfile.mkdtemp()
-    >>> for root, dirs, files in os.walk(local_dir):
-    >>>     for file in files:
-    >>>         src_path = os.path.join(root, file)
-    >>>         shutil.copy(src_path, artifact_dir)
+    >>> # copy all downloaded files to artifact folder
+    >>> for file in allow_patterns:
+    >>>     shutil.copy(local_dir + "/" + file, artifact_dir)
 
     >>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
     >>> model.summary_status()
@@ -157,8 +165,8 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
 
     def __init__(
         self,
-        artifact_dir: str | None = None,
-        auth: Dict | None = None,
+        artifact_dir: Optional[str] = None,
+        auth: Optional[Dict] = None,
         serialize: bool = False,
         **kwargs: dict,
     ):
@@ -191,18 +199,26 @@ def __init__(
         >>> from huggingface_hub import snapshot_download
 
         >>> local_dir=tempfile.mkdtemp()
-        >>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
+        >>> allow_patterns=[
+        ...     "onnx/model.onnx",
+        ...     "config.json",
+        ...     "special_tokens_map.json",
+        ...     "tokenizer_config.json",
+        ...     "tokenizer.json",
+        ...     "vocab.txt"
+        ... ]
+
+        >>> # download the files needed for this demonstration to a local folder
         >>> snapshot_download(
         ...     repo_id="sentence-transformers/all-MiniLM-L6-v2",
-        ...     local_dir=local_dir
+        ...     local_dir=local_dir,
+        ...     allow_patterns=allow_patterns
         ... )
 
-        >>> # copy all files from subdirectory to artifact_dir
         >>> artifact_dir = tempfile.mkdtemp()
-        >>> for root, dirs, files in os.walk(local_dir):
-        >>>     for file in files:
-        >>>         src_path = os.path.join(root, file)
-        >>>         shutil.copy(src_path, artifact_dir)
+        >>> # copy all downloaded files to artifact folder
+        >>> for file in allow_patterns:
+        ...     shutil.copy(os.path.join(local_dir, file), artifact_dir)
 
         >>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
         >>> model.summary_status()
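One behavior worth noting in the revised copy loop: ``onnx/model.onnx`` sits in a subdirectory of the downloaded snapshot, and ``shutil.copy(src, artifact_dir)`` writes the file directly under ``artifact_dir``, dropping the ``onnx/`` prefix. The flattening appears intentional, since the scoring template loads the model from the artifact root. A short sketch of the effect (paths are illustrative):

    import os
    import shutil
    import tempfile

    local_dir = tempfile.mkdtemp()
    artifact_dir = tempfile.mkdtemp()
    os.makedirs(os.path.join(local_dir, "onnx"))
    open(os.path.join(local_dir, "onnx", "model.onnx"), "wb").close()

    shutil.copy(os.path.join(local_dir, "onnx", "model.onnx"), artifact_dir)
    # the file lands at <artifact_dir>/model.onnx, not <artifact_dir>/onnx/model.onnx
    assert os.path.exists(os.path.join(artifact_dir, "model.onnx"))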
diff --git a/ads/templates/score_embedding_onnx.jinja2 b/ads/templates/score_embedding_onnx.jinja2
index 9d01a9cd7..b7de4ca18 100644
--- a/ads/templates/score_embedding_onnx.jinja2
+++ b/ads/templates/score_embedding_onnx.jinja2
@@ -2,6 +2,7 @@
 import os
 import sys
 import json
+import subprocess
 from functools import lru_cache
 import onnxruntime as ort
 import jsonschema
@@ -33,13 +34,26 @@ def load_model(model_file_name=model_name):
     contents = os.listdir(model_dir)
     if model_file_name in contents:
         print(f'Start loading {model_file_name} from model directory {model_dir} ...')
-        model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=['CUDAExecutionProvider','CPUExecutionProvider'])
+        providers = ['CPUExecutionProvider']
+        if is_gpu_available():
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=providers)
         print("Model is successfully loaded.")
         return model
     else:
         raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
 
 
+def is_gpu_available():
+    """Check if gpu is available on the infrastructure."""
+    try:
+        result = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        if result.returncode == 0:
+           return True
+    except FileNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=1)
 def load_tokenizer(model_full_name):
 
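An alternative to shelling out to ``nvidia-smi``: onnxruntime reports which execution providers its installed build supports, which also covers the case where a GPU is present but only the CPU wheel of onnxruntime is installed. A sketch of that variant, shown for comparison rather than as what the template does (``model_path`` is illustrative):

    import onnxruntime as ort

    def pick_providers():
        # get_available_providers() lists the providers compiled into this build
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available:
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]

    # ort.InferenceSession(model_path, providers=pick_providers()) then falls
    # back to CPU automatically if CUDA initialization fails at session creation.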
diff --git a/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst b/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst
index 9fd22f6b1..ce081bde1 100644
--- a/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst
+++ b/docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst
@@ -6,7 +6,7 @@ See `API Documentation <../../../ads.model.framework.html#ads.model.framework.em
 Overview
 ========
 
-The ``ads.model.framework.embedding_onnx_model.EmbeddingONNXModel`` class in ADS is designed to rapidly get an Embedding ONNX Model into production. The ``.prepare()`` method creates the model artifacts that are needed without configuring it or writing code. However, you can customize the required ``score.py`` file.
+The ``ads.model.framework.embedding_onnx_model.EmbeddingONNXModel`` class in ADS is designed to rapidly get an Embedding ONNX Model into production. The ``.prepare()`` method creates the model artifacts that are needed without configuring it or writing code. ``EmbeddingONNXModel`` supports the `OpenAI spec <https://github.com/huggingface/text-embeddings-inference/blob/main/docs/openapi.json>`_ for the embeddings endpoint.
 
 .. include:: ../_template/overview.rst
 
@@ -24,26 +24,26 @@ The following steps take the `sentence-transformers/all-MiniLM-L6-v2 <https://hu
 
     local_dir = tempfile.mkdtemp()
 
+    allow_patterns=[
+        "onnx/model.onnx",
+        "config.json",
+        "special_tokens_map.json",
+        "tokenizer_config.json",
+        "tokenizer.json",
+        "vocab.txt"
+    ]
+
-    # download files needed for this demostration to local folder
+    # download the files needed for this demonstration to a local folder
     snapshot_download(
         repo_id="sentence-transformers/all-MiniLM-L6-v2",
         local_dir=local_dir,
-        allow_patterns=[
-            "onnx/model.onnx",
-            "config.json",
-            "special_tokens_map.json",
-            "tokenizer_config.json",
-            "tokenizer.json",
-            "vocab.txt"
-        ]
+        allow_patterns=allow_patterns
     )
 
     artifact_dir = tempfile.mkdtemp()
     # copy all downloaded files to artifact folder
-    for root, dirs, files in os.walk(local_dir):
-        for file in files:
-            src_path = os.path.join(root, file)
-            shutil.copy(src_path, artifact_dir)
+    for file in allow_patterns:
+        shutil.copy(os.path.join(local_dir, file), artifact_dir)
 
 
 Install Conda Pack
@@ -213,26 +213,26 @@ Example
 
     local_dir = tempfile.mkdtemp()
 
-    # download files needed for the demostration to local folder
+    allow_patterns=[
+        "onnx/model.onnx",
+        "config.json",
+        "special_tokens_map.json",
+        "tokenizer_config.json",
+        "tokenizer.json",
+        "vocab.txt"
+    ]
+
+    # download the files needed for this demonstration to a local folder
     snapshot_download(
         repo_id="sentence-transformers/all-MiniLM-L6-v2",
         local_dir=local_dir,
-        allow_patterns=[
-            "onnx/model.onnx",
-            "config.json",
-            "special_tokens_map.json",
-            "tokenizer_config.json",
-            "tokenizer.json",
-            "vocab.txt"
-        ]
+        allow_patterns=allow_patterns
     )
 
     artifact_dir = tempfile.mkdtemp()
     # copy all downloaded files to artifact folder
-    for root, dirs, files in os.walk(local_dir):
-        for file in files:
-            src_path = os.path.join(root, file)
-            shutil.copy(src_path, artifact_dir)
+    for file in allow_patterns:
+        shutil.copy(os.path.join(local_dir, file), artifact_dir)
 
     # initialize EmbeddingONNXModel instance and prepare score.py, runtime.yaml and openapi.json files.
     embedding_onnx_model = EmbeddingONNXModel(artifact_dir=artifact_dir)

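For reference, the OpenAI-style embeddings contract that the overview now points to exchanges JSON shaped roughly as below; field names follow the linked OpenAPI document, and the exact payload accepted depends on the generated ``score.py``:

    import json

    request = {"input": ["What are activation functions?"]}
    # abridged response shape: one vector per input, in input order
    response = {
        "object": "list",
        "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02]}],
    }
    print(json.dumps(request), json.dumps(response))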
From 5bd380cd3c1f6cbfbaf01bf4f3e30c75fc99c502 Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Mon, 6 Jan 2025 11:15:52 -0500
Subject: [PATCH 6/7] Added artifacts validation.

---
 ads/model/__init__.py                         |  2 +-
 ads/model/artifact.py                         |  2 +-
 .../extractor/embedding_onnx_extractor.py     |  2 +-
 ads/model/framework/embedding_onnx_model.py   | 88 ++++++++++++++++++--
 ...st_model_framework_embedding_onnx_model.py | 22 ++++-
 5 files changed, 105 insertions(+), 11 deletions(-)

diff --git a/ads/model/__init__.py b/ads/model/__init__.py
index f0b0febae..6a684c54b 100644
--- a/ads/model/__init__.py
+++ b/ads/model/__init__.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 from ads.model.datascience_model import DataScienceModel
diff --git a/ads/model/artifact.py b/ads/model/artifact.py
index 4c116eb78..0d153d0a3 100644
--- a/ads/model/artifact.py
+++ b/ads/model/artifact.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import fnmatch
diff --git a/ads/model/extractor/embedding_onnx_extractor.py b/ads/model/extractor/embedding_onnx_extractor.py
index 9f3f6b463..bca38f62c 100644
--- a/ads/model/extractor/embedding_onnx_extractor.py
+++ b/ads/model/extractor/embedding_onnx_extractor.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 from ads.common.decorator.runtime_dependency import (
diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
index 316b37599..2ad11321d 100644
--- a/ads/model/framework/embedding_onnx_model.py
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -1,13 +1,27 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
+import logging
+import os
+from pathlib import Path
 from typing import Dict, Optional
 
 from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
 from ads.model.generic_model import FrameworkSpecificModel
 
+logger = logging.getLogger(__name__)
+
+CONFIG = "config.json"
+TOKENIZERS = [
+    "tokenizer.json",
+    "tokenizer_config.json",
+    "spiece.model",
+    "vocab.txt",
+    "vocab.json",
+]
+
 
 class EmbeddingONNXModel(FrameworkSpecificModel):
     """EmbeddingONNXModel class for embedding onnx model.
@@ -18,6 +32,12 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
         The algorithm of the model.
     artifact_dir: str
         Artifact directory to store the files needed for deployment.
+    model_file_name: str
+        Path to the model artifact.
+    config_json: str
+        Path to the config.json file.
+    tokenizer_dir: str
+        Path to the tokenizer directory.
     auth: Dict
         Default authentication is set using the `ads.set_auth` API. To override the
         default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
@@ -166,6 +186,9 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
     def __init__(
         self,
         artifact_dir: Optional[str] = None,
+        model_file_name: Optional[str] = None,
+        config_json: Optional[str] = None,
+        tokenizer_dir: Optional[str] = None,
         auth: Optional[Dict] = None,
         serialize: bool = False,
         **kwargs: dict,
@@ -175,8 +198,14 @@ def __init__(
 
         Parameters
         ----------
-        artifact_dir: str
-            Directory for generate artifact.
+        artifact_dir: (str, optional). Defaults to None.
+            Directory for the generated model artifacts.
+        model_file_name: (str, optional). Defaults to None.
+            Path to the model artifact.
+        config_json: (str, optional). Defaults to None.
+            Path to the config.json file.
+        tokenizer_dir: (str, optional). Defaults to None.
+            Path to the tokenizer directory.
         auth: (Dict, optional). Defaults to None.
-            The default authetication is set using `ads.set_auth` API. If you need to override the
+            The default authentication is set using `ads.set_auth` API. If you need to override the
             default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
@@ -260,12 +289,63 @@ def __init__(
             **kwargs,
         )
 
+        self._validate_artifact_directory(
+            model_file_name=model_file_name,
+            config_json=config_json,
+            tokenizer_dir=tokenizer_dir,
+        )
+
         self._extractor = EmbeddingONNXExtractor()
         self.framework = self._extractor.framework
         self.algorithm = self._extractor.algorithm
         self.version = self._extractor.version
         self.hyperparameter = self._extractor.hyperparameter
 
+    def _validate_artifact_directory(
+        self,
+        model_file_name: Optional[str] = None,
+        config_json: Optional[str] = None,
+        tokenizer_dir: Optional[str] = None,
+    ):
+        artifacts = []
+        for _, _, files in os.walk(self.artifact_dir):
+            artifacts.extend(files)
+
+        if not artifacts:
+            raise ValueError(
+                f"No files found in {self.artifact_dir}. Specify a valid `artifact_dir`."
+            )
+
+        if not model_file_name:
+            has_model_file = False
+            for artifact in artifacts:
+                if Path(artifact).suffix.lstrip(".").lower() == "onnx":
+                    has_model_file = True
+                    break
+
+            if not has_model_file:
+                raise ValueError(
+                    f"No onnx model found in {self.artifact_dir}. Specify a valid `artifact_dir` or `model_file_name`."
+                )
+
+        if not config_json:
+            if CONFIG not in artifacts:
+                logger.warning(
+                    f"No {CONFIG} found in {self.artifact_dir}. Specify a valid `artifact_dir` or `config_json`."
+                )
+
+        if not tokenizer_dir:
+            has_tokenizer = False
+            for artifact in artifacts:
+                if artifact in TOKENIZERS:
+                    has_tokenizer = True
+                    break
+
+            if not has_tokenizer:
+                logger.warning(
+                    f"No tokenizer found in {self.artifact_dir}. Specify a valid `artifact_dir` or `tokenizer_dir`."
+                )
+
     def verify(
         self, data=None, reload_artifacts=True, auto_serialize_data=False, **kwargs
     ):
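The validation above is deliberately tiered: a missing ONNX file is fatal, while a missing ``config.json`` or tokenizer only logs a warning, because a caller may supply those locations explicitly via ``config_json`` or ``tokenizer_dir``. A condensed, standalone restatement of the checks (same file lists as the patch; not the ADS API itself):

    import logging
    import os
    from pathlib import Path

    logger = logging.getLogger(__name__)
    CONFIG = "config.json"
    TOKENIZERS = ["tokenizer.json", "tokenizer_config.json", "spiece.model", "vocab.txt", "vocab.json"]

    def check_artifacts(artifact_dir):
        # collect every file name anywhere under artifact_dir
        artifacts = [f for _, _, files in os.walk(artifact_dir) for f in files]
        if not artifacts:
            raise ValueError(f"No files found in {artifact_dir}.")
        if not any(Path(f).suffix.lower() == ".onnx" for f in artifacts):
            raise ValueError(f"No onnx model found in {artifact_dir}.")
        if CONFIG not in artifacts:
            logger.warning("No %s found in %s.", CONFIG, artifact_dir)
        if not any(f in TOKENIZERS for f in artifacts):
            logger.warning("No tokenizer found in %s.", artifact_dir)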
diff --git a/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py b/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
index 571cb298b..152121fb8 100644
--- a/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
+++ b/tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import os
@@ -21,13 +21,20 @@ def setup_class(cls):
         cls.inference_conda = "oci://fake_bucket@fake_namespace/inference_conda"
         cls.training_conda = "oci://fake_bucket@fake_namespace/training_conda"
 
-    def test_init(self):
+    @patch(
+        "ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
+    )
+    def test_init(self, mock_validate):
         model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
         assert model.algorithm == "Embedding_ONNX"
         assert model.framework == Framework.EMBEDDING_ONNX
+        mock_validate.assert_called()
 
     @patch("ads.model.generic_model.GenericModel.verify")
-    def test_prepare_and_verify(self, mock_verify):
+    @patch(
+        "ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
+    )
+    def test_prepare_and_verify(self, mock_validate, mock_verify):
         mock_verify.return_value = {"results": "successful"}
 
         model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
@@ -87,11 +94,17 @@ def test_prepare_and_verify(self, mock_verify):
             reload_artifacts=True,
             auto_serialize_data=False,
         )
+        mock_validate.assert_called()
 
     @patch("ads.model.generic_model.GenericModel.predict")
     @patch("ads.model.generic_model.GenericModel.deploy")
     @patch("ads.model.generic_model.GenericModel.save")
-    def test_prepare_save_deploy_predict(self, mock_save, mock_deploy, mock_predict):
+    @patch(
+        "ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
+    )
+    def test_prepare_save_deploy_predict(
+        self, mock_validate, mock_save, mock_deploy, mock_predict
+    ):
         model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
         model.prepare(
             model_file_name="test_model_file_name",
@@ -127,6 +140,7 @@ def test_prepare_save_deploy_predict(self, mock_save, mock_deploy, mock_predict)
             deployment_ocpus=20,
             deployment_memory_in_gbs=256,
         )
+        mock_validate.assert_called()
 
     def teardown_class(cls):
         shutil.rmtree(cls.tmp_model_dir, ignore_errors=True)

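The ``_validate_artifact_directory`` patches in these tests are needed because ``tmp_model_dir`` contains no ONNX artifacts, so the constructor would now raise before the behavior under test runs. A sketch of the check the mock sidesteps (test name is illustrative; assumes ads and pytest are installed):

    import tempfile

    import pytest

    from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel

    def test_empty_artifact_dir_rejected():
        # an empty directory trips the "No files found" ValueError
        with pytest.raises(ValueError):
            EmbeddingONNXModel(artifact_dir=tempfile.mkdtemp())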
From f848bb104b070c22a01f23b42c20b640ad3e9d15 Mon Sep 17 00:00:00 2001
From: Lu Peng <bolu.peng@oracle.com>
Date: Wed, 22 Jan 2025 09:58:13 -0500
Subject: [PATCH 7/7] Added tokenizers.

---
 ads/model/framework/embedding_onnx_model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ads/model/framework/embedding_onnx_model.py b/ads/model/framework/embedding_onnx_model.py
index 2ad11321d..636a0c504 100644
--- a/ads/model/framework/embedding_onnx_model.py
+++ b/ads/model/framework/embedding_onnx_model.py
@@ -16,7 +16,9 @@
 CONFIG = "config.json"
 TOKENIZERS = [
     "tokenizer.json",
+    "tokenizer.model",
     "tokenizer_config.json",
+    "sentencepiece.bpe.model",
     "spiece.model",
     "vocab.txt",
     "vocab.json",