Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ODSC-64654 register model artifact reference #1008

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions ads/model/datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,26 @@ def restore_model(
restore_model_for_hours_specified=restore_model_for_hours_specified,
)

def register_model_artifact_reference(self,bucket_uri_list: List[str]) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, since we already have the .with_artifact() logic, why introduce a new method? Wouldn't it be better to enhance the existing one? It's a well-known method, and users are already familiar with it. Also will download_artifact() method still work?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mrDzurb .with_artifact() method already does model-by-ref but using uploadArtifact api. We want to keep that flow too for now, till we completely enable and transition to registerArtifact API. I didn't want to introduce any breaking change, that's why have NOT modified existing method. Moreover, register_model_artifact_reference as a separate method is also consistent with export_model_artifact method, which in a way does similar job.
Yes, the download_artifact method will still work, it will download the json configuration file uploaded by the registerArtifact api.

"""
Registers model artifact references against a model.

Can be used for any model that does not have a model artifact yet. Requires a list of
OCI Object Storage URIs pointing at the locations which contain the artifacts.
The request is delegated to ``self.dsc_model.register_model_artifact_reference``.

Parameters
----------
bucket_uri_list: List[str]
    The list of OCI Object Storage URIs where model artifacts are present.
    Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].

Returns
-------
None
"""
self.dsc_model.register_model_artifact_reference(
bucket_uri_list=bucket_uri_list
)

def download_artifact(
self,
target_dir: str,
Expand Down
50 changes: 49 additions & 1 deletion ads/model/service/oci_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ExportModelArtifactDetails,
ImportModelArtifactDetails,
UpdateModelDetails,
WorkRequest,
WorkRequest, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, ModelArtifactReferenceDetails,
)
from oci.exceptions import ServiceError

Expand Down Expand Up @@ -449,6 +449,54 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
progress_bar_description="Exporting model artifacts."
)

@check_for_model_id(
    msg="Model needs to be saved to the Model Catalog before the artifact can be registered against it."
)
def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None:
    """
    Registers model artifact references against a model.

    Can be used for any model that does not have a model artifact yet. Requires a list of
    OCI Object Storage URIs pointing at the locations which contain the artifacts.
    Each URI is parsed into namespace, bucket name and (optional) prefix, the references are
    submitted in a single request, and the resulting work request is awaited.

    Parameters
    ----------
    bucket_uri_list: List[str]
        The list of OCI Object Storage URIs where model artifacts are present.
        Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].

    Returns
    -------
    None

    Raises
    ------
    Exception
        Whatever error was raised while waiting for the work request. The work request
        errors are logged first and the original exception is then re-raised.
    """
    model_artifact_reference_details_list = []
    for bucket_uri in bucket_uri_list:
        bucket_details = ObjectStorageDetails.from_path(bucket_uri)
        model_artifact_reference_details = OSSModelArtifactReferenceDetails()
        model_artifact_reference_details.namespace = bucket_details.namespace
        model_artifact_reference_details.bucket_name = bucket_details.bucket
        # A URI that points at the bucket root carries no prefix; leave it unset in that case.
        prefix = (bucket_details.filepath or "").strip("/")
        if prefix:
            model_artifact_reference_details.prefix = prefix
        model_artifact_reference_details_list.append(model_artifact_reference_details)

    register_model_artifact_reference_details = RegisterModelArtifactReferenceDetails()
    register_model_artifact_reference_details.model_artifact_references = (
        model_artifact_reference_details_list
    )

    work_request_id = self.client.register_model_artifact_reference(
        model_id=self.id,
        register_model_artifact_reference_details=register_model_artifact_reference_details,
    ).headers["opc-work-request-id"]

    # Show progress of model artifact references being registered
    try:
        DataScienceWorkRequest(work_request_id).wait_work_request(
            progress_bar_description="Registering model artifact references."
        )
    except Exception:
        logger.error(f"WorkRequest: `{work_request_id}` failed. Fetching Work Request Error Logs.")
        try:
            get_work_request_errors_response = self.client.list_work_request_errors(work_request_id)
            logger.error(get_work_request_errors_response.data)
        except Exception as log_ex:
            # Fetching the logs is best effort; it must not hide the original failure.
            logger.error(f"Unable to fetch Work Request Error Logs: {log_ex}")
        # Propagate the failure so callers do not mistake it for a successful registration.
        raise
    else:
        logger.info("Artifact references registered successfully.")

@check_for_model_id(
msg="Model needs to be saved to the Model Catalog before it can be updated."
)
Expand Down
43 changes: 42 additions & 1 deletion docs/source/user_guide/model_catalog/model_catalog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1553,4 +1553,45 @@ In the next example, the model that was stored in the model catalog as part of t
Restore Archived Model
**********************

The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.
The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.

Register Model Artifact Reference
*********************************

The ``.register_model_artifact_reference()`` method of Model catalog registers the references of your OCI Object Storage buckets where the artifact files are present against the model.

By using this API, you can avoid the need to upload or export large model artifacts, and can simply provide references to the OCI Object Storage locations where your artifacts are present. OCI Data Science reads the artifact files directly from those locations when you create a deployment of the model.

The input to this method is a List of bucket_uri(s). The URI syntax for the bucket_uri is:

oci://<bucket_name>@<namespace>/<path>/

Example -

.. code-block:: python3

model.register_model_artifact_reference(
bucket_uri_list = ["oci://<bucket_name>@<namespace>/<path>/"]
)

Important Points:

1. The buckets provided should be in the same region and have versioning enabled on them.

2. The <path> is optional. If the files that you want to use for this model are within a path in the bucket, then the path can be specified in the bucket_uri; otherwise it can be omitted, as shown below:

oci://<bucket_name>@<namespace>/

3. The location specified by bucket_uri should have at least one object within it.

4. Make sure that the buckets provided have the following IAM policy configured to allow the Data Science service to read artifact files from those Object Storage buckets in your tenancy. An administrator must configure these policies in `IAM <https://docs.oracle.com/iaas/Content/Identity/home1.htm>`_ in the Console.

.. parsed-literal::

allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/}

If you want, you can have a more granular policy by having an additional filter on project_id like below, which will then give access to the bucket only to models present in the data science project specified in the filter.

.. parsed-literal::

allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/, request.principal.project_id = '<project_ocid>'}
18 changes: 18 additions & 0 deletions tests/unitary/default_setup/model/test_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,24 @@ def test_upload_artifact(self):
)
mock_upload.assert_called()

@patch.object(OCIDataScienceModel, "register_model_artifact_reference")
def test_register_model_artifact_reference(self, mock_register_model_artifact_reference):
    """Tests that registering artifact references is forwarded to the patched OCI model method."""
    # Two sample Object Storage locations to register
    uris = [
        "oci://bucket1@namespace1/prefix1/",
        "oci://bucket2@namespace2/prefix2/",
    ]

    # Exercise the method under test
    self.mock_dsc_model.register_model_artifact_reference(bucket_uri_list=uris)

    # The patched method must have been invoked exactly once with the same URIs
    mock_register_model_artifact_reference.assert_called_once_with(
        bucket_uri_list=uris
    )

def test_download_artifact(self):
"""Tests downloading artifacts from the model catalog."""
# Artifact size greater than 2GB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ExportModelArtifactDetails,
ImportModelArtifactDetails,
Model,
ModelProvenance,
ModelProvenance, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails,
)
from oci.exceptions import ServiceError
from oci.response import Response
Expand Down Expand Up @@ -136,6 +136,12 @@ def setup_class(cls):
headers={"opc-work-request-id": "work_request_id"},
request=None,
)
cls.mock_register_model_artifact_reference_response = Response(
data=None,
status=None,
headers={"opc-work-request-id": "work_request_id"},
request=None,
)

def setup_method(self):
self.mock_model = OCIDataScienceModel(**OCI_MODEL_PAYLOAD)
Expand Down Expand Up @@ -173,6 +179,9 @@ def mock_client(self):
mock_client.export_model_artifact = MagicMock(
return_value=self.mock_export_artifact_response
)
mock_client.register_model_artifact_reference = MagicMock(
return_value=self.mock_register_model_artifact_reference_response
)
return mock_client

def test_create_fail(self):
Expand Down Expand Up @@ -463,6 +472,62 @@ def test_export_model_artifact(
progress_bar_description="Exporting model artifacts."
)

@patch(
    "ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request"
)
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__")
def test_register_model_artifact_reference(
    self,
    mock_data_science_work_request,
    mock_wait_work_request,
    mock_client,
):
    """Tests register model artifact reference for a model in model catalog."""
    # Each case: (input URI, expected namespace, expected bucket, expected prefix).
    # Covers a trailing slash, a nested prefix, and bucket-root URIs with and without a slash.
    cases = [
        ("oci://bucket1@namespace1/prefix1/", "namespace1", "bucket1", "prefix1"),
        ("oci://bucket2@namespace2/prefix2/subPrefix2", "namespace2", "bucket2", "prefix2/subPrefix2"),
        ("oci://bucket3@namespace3/", "namespace3", "bucket3", None),
        ("oci://bucket4@namespace4", "namespace4", "bucket4", None),
    ]

    expected_references = []
    for _, namespace, bucket_name, prefix in cases:
        reference = OSSModelArtifactReferenceDetails()
        reference.namespace = namespace
        reference.bucket_name = bucket_name
        reference.prefix = prefix
        expected_references.append(reference)

    expected_details = RegisterModelArtifactReferenceDetails()
    expected_details.model_artifact_references = expected_references

    mock_data_science_work_request.return_value = None
    with patch.object(OCIDataScienceModel, "client", mock_client):
        self.mock_model.register_model_artifact_reference(
            bucket_uri_list=[case[0] for case in cases]
        )
        mock_client.register_model_artifact_reference.assert_called_with(
            model_id=self.mock_model.id,
            register_model_artifact_reference_details=expected_details,
        )
        mock_data_science_work_request.assert_called_with("work_request_id")
        mock_wait_work_request.assert_called_with(
            progress_bar_description="Registering model artifact references."
        )

def test_is_model_by_reference(self):
"""Test to check if model is created by reference using custom metadata information"""

Expand Down
Loading