Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added framework to support embedding onnx model #1027

Merged
merged 13 commits into from
Jan 27, 2025
Merged
24 changes: 11 additions & 13 deletions ads/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.model.generic_model import GenericModel, ModelState
from ads.model.datascience_model import DataScienceModel
from ads.model.model_properties import ModelProperties
from ads.model.deployment.model_deployer import ModelDeployer
from ads.model.deployment.model_deployment import ModelDeployment
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
from ads.model.framework.automl_model import AutoMLModel
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
from ads.model.framework.lightgbm_model import LightGBMModel
from ads.model.framework.pytorch_model import PyTorchModel
from ads.model.framework.sklearn_model import SklearnModel
from ads.model.framework.spark_model import SparkPipelineModel
from ads.model.framework.tensorflow_model import TensorFlowModel
from ads.model.framework.xgboost_model import XGBoostModel
from ads.model.framework.spark_model import SparkPipelineModel
from ads.model.framework.huggingface_model import HuggingFacePipelineModel

from ads.model.deployment.model_deployer import ModelDeployer
from ads.model.deployment.model_deployment import ModelDeployment
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties

from ads.model.generic_model import GenericModel, ModelState
from ads.model.model_properties import ModelProperties
from ads.model.model_version_set import ModelVersionSet, experiment
from ads.model.serde.common import SERDE
from ads.model.serde.model_input import ModelInputSerializer

from ads.model.model_version_set import ModelVersionSet, experiment
from ads.model.service.oci_datascience_model_version_set import (
ModelVersionSetNotExists,
ModelVersionSetNotSaved,
Expand All @@ -42,6 +39,7 @@
"XGBoostModel",
"SparkPipelineModel",
"HuggingFacePipelineModel",
"EmbeddingONNXModel",
"ModelDeployer",
"ModelDeployment",
"ModelDeploymentProperties",
Expand Down
55 changes: 47 additions & 8 deletions ads/model/artifact.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import fnmatch
import importlib
import os
import sys
import shutil
import sys
import tempfile
import uuid
import fsspec
from datetime import datetime
from typing import Dict, Optional, Tuple

import fsspec
from jinja2 import Environment, PackageLoader

from ads import __version__
from ads.common import auth as authutil
from ads.common import logger, utils
from ads.common.object_storage_details import ObjectStorageDetails
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
from ads.model.runtime.runtime_info import RuntimeInfo
from jinja2 import Environment, PackageLoader
import warnings
from ads import __version__
from datetime import datetime

MODEL_ARTIFACT_VERSION = "3.0"
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
Expand Down Expand Up @@ -378,6 +378,45 @@ def prepare_score_py(
) as f:
f.write(scorefn_template.render(context))

def prepare_schema(self, schema_name: str):
"""Copies schema to artifact directory.

Parameters
----------
schema_name: str
The schema name

Returns
-------
None

Raises
------
FileExistsError
If `schema_name` doesn't exist.
"""
uri_src = os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
"templates",
"schemas",
f"{schema_name}",
)

if not os.path.exists(uri_src):
raise FileExistsError(
f"{schema_name} does not exists. "
"Ensure the schema name is valid or specify a different one."
)

uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))

utils.copy_file(
uri_src=uri_src,
uri_dst=uri_dst,
force_overwrite=True,
auth=self.auth,
)

def reload(self):
"""Syncs the `score.py` to reload the model and predict function.

Expand Down
80 changes: 80 additions & 0 deletions ads/model/extractor/embedding_onnx_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python

# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.common.decorator.runtime_dependency import (
OptionalDependency,
runtime_dependency,
)
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
from ads.model.model_metadata import Framework


class EmbeddingONNXExtractor(ModelInfoExtractor):
"""Class that extract model metadata from EmbeddingONNXModel models.

Attributes
----------
model: object
The model to extract metadata from.

Methods
-------
framework(self) -> str
Returns the framework of the model.
algorithm(self) -> object
Returns the algorithm of the model.
version(self) -> str
Returns the version of framework of the model.
hyperparameter(self) -> dict
Returns the hyperparameter of the model.
"""

def __init__(self, model=None):
self.model = model

@property
def framework(self):
"""Extracts the framework of the model.

Returns
----------
str:
The framework of the model.
"""
return Framework.EMBEDDING_ONNX

@property
def algorithm(self):
"""Extracts the algorithm of the model.

Returns
----------
object:
The algorithm of the model.
"""
return "Embedding_ONNX"

@property
@runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
def version(self):
"""Extracts the framework version of the model.

Returns
----------
str:
The framework version of the model.
"""
return onnxruntime.__version__

@property
def hyperparameter(self):
"""Extracts the hyperparameters of the model.

Returns
----------
dict:
The hyperparameters of the model.
"""
return None
Loading
Loading