Skip to content

Commit

Permalink
ODSC-64654 register model artifact reference
Browse files Browse the repository at this point in the history
ODSC-64654-register-model-artifact-reference
  • Loading branch information
tkg2261 committed Nov 18, 2024
1 parent b15138c commit 298ccc6
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 3 deletions.
20 changes: 20 additions & 0 deletions ads/model/datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,26 @@ def restore_model(
restore_model_for_hours_specified=restore_model_for_hours_specified,
)

def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None:
    """
    Registers model artifact references against this model.

    Intended for a model that does not have a model artifact yet. The call is
    delegated unchanged to the underlying ``dsc_model`` service object.

    Parameters
    ----------
    bucket_uri_list: List[str]
        OCI Object Storage URIs under which the model artifacts are stored.
        Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].

    Returns
    -------
    None
    """
    # Keyword form is kept on purpose: callers and tests match on the keyword name.
    self.dsc_model.register_model_artifact_reference(bucket_uri_list=bucket_uri_list)

def download_artifact(
self,
target_dir: str,
Expand Down
50 changes: 49 additions & 1 deletion ads/model/service/oci_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ExportModelArtifactDetails,
ImportModelArtifactDetails,
UpdateModelDetails,
WorkRequest,
WorkRequest, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, ModelArtifactReferenceDetails,
)
from oci.exceptions import ServiceError

Expand Down Expand Up @@ -449,6 +449,54 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
progress_bar_description="Exporting model artifacts."
)

@check_for_model_id(
    msg="Model needs to be saved to the Model Catalog before the artifact can be registered against it."
)
def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None:
    """
    Registers model artifact references against a model.

    Can be used for any model for which a model artifact doesn't exist yet.
    Each URI is parsed into namespace, bucket and (optional) prefix, the
    references are submitted in a single request, and the resulting work
    request is awaited.

    Parameters
    ----------
    bucket_uri_list: List[str]
        The list of OCI Object Storage URIs where model artifacts are present.
        Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].

    Returns
    -------
    None

    Raises
    ------
    Exception
        Whatever error waiting on the work request raised. It is re-raised
        after the work request errors have been logged, so that a failed
        registration is not reported to the caller as a success.
    """
    model_artifact_reference_details_list = []
    for bucket_uri in bucket_uri_list:
        bucket_details = ObjectStorageDetails.from_path(bucket_uri)
        model_artifact_reference_details = OSSModelArtifactReferenceDetails()
        model_artifact_reference_details.namespace = bucket_details.namespace
        model_artifact_reference_details.bucket_name = bucket_details.bucket
        # Leave the prefix unset when the URI has no path, or a path made only
        # of slashes, rather than sending an empty-string prefix.
        prefix = (bucket_details.filepath or "").strip("/")
        if prefix:
            model_artifact_reference_details.prefix = prefix
        model_artifact_reference_details_list.append(model_artifact_reference_details)

    register_model_artifact_reference_details = RegisterModelArtifactReferenceDetails()
    register_model_artifact_reference_details.model_artifact_references = model_artifact_reference_details_list

    work_request_id = self.client.register_model_artifact_reference(
        model_id=self.id,
        register_model_artifact_reference_details=register_model_artifact_reference_details
    ).headers["opc-work-request-id"]

    # Show progress of model artifact references being registered
    try:
        DataScienceWorkRequest(work_request_id).wait_work_request(
            progress_bar_description="Registering model artifact references."
        )
    except Exception:
        logger.error(f"WorkRequest: `{work_request_id}` failed. Fetching Work Request Error Logs.")
        # Fetching the error logs is best effort: a failure here must not
        # replace the original error that is re-raised below.
        try:
            get_work_request_errors_response = self.client.list_work_request_errors(work_request_id)
            logger.error(get_work_request_errors_response.data)
        except Exception as log_ex:
            logger.error(f"Unable to fetch errors of WorkRequest `{work_request_id}`: {log_ex}")
        raise
    else:
        logger.info("Artifact references registered successfully.")

@check_for_model_id(
msg="Model needs to be saved to the Model Catalog before it can be updated."
)
Expand Down
43 changes: 42 additions & 1 deletion docs/source/user_guide/model_catalog/model_catalog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1553,4 +1553,45 @@ In the next example, the model that was stored in the model catalog as part of t
Restore Archived Model
**********************

The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.
The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.

Register Model Artifact Reference
*********************************

The ``.register_model_artifact_reference()`` method of Model catalog registers the references of your OCI Object Storage buckets where the artifact files are present against the model.

By using this API, you avoid the need to upload or export large model artifacts: you simply provide references to the OCI Object Storage locations where your artifacts are stored. The OCI Data Science service reads the artifact files directly from those locations when you create a deployment of the model.

The input to this method is a List of bucket_uri(s). The URI syntax for the bucket_uri is:

oci://<bucket_name>@<namespace>/<path>/

Example -

.. code-block:: python3

    model.register_model_artifact_reference(
        bucket_uri_list=["oci://<bucket_name>@<namespace>/<path>/"]
    )

Important Points:

1. The buckets provided must be in the same region and must have versioning enabled.

2. The <path> is optional. If the files you want to use for this model are under a path within the bucket, specify that path in the bucket_uri; otherwise omit it, as shown below:

oci://<bucket_name>@<namespace>/

3. The location specified by bucket_uri should contain at least one object.

4. Make sure that the buckets provided have the following IAM policy configured, which allows the Data Science service to read artifact files from those Object Storage buckets in your tenancy. An administrator must configure these policies in `IAM <https://docs.oracle.com/iaas/Content/Identity/home1.htm>`_ in the Console.

.. parsed-literal::

    allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/}

If you want a more granular policy, add a filter on project_id as shown below. This gives access to the bucket only to models in the Data Science project specified in the filter.

.. parsed-literal::

    allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/, request.principal.project_id = '<project_ocid>'}
18 changes: 18 additions & 0 deletions tests/unitary/default_setup/model/test_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,24 @@ def test_upload_artifact(self):
)
mock_upload.assert_called()

@patch.object(OCIDataScienceModel, 'register_model_artifact_reference')
def test_register_model_artifact_reference(self, mock_register_model_artifact_reference):
    """Tests that registering artifact references calls the patched
    ``OCIDataScienceModel.register_model_artifact_reference`` exactly once
    with the same ``bucket_uri_list`` keyword argument."""
    expected_uris = [
        "oci://bucket1@namespace1/prefix1/",
        "oci://bucket2@namespace2/prefix2/",
    ]

    self.mock_dsc_model.register_model_artifact_reference(bucket_uri_list=expected_uris)

    # The service-level method is patched, so only the hand-off is verified here.
    mock_register_model_artifact_reference.assert_called_once_with(
        bucket_uri_list=expected_uris
    )

def test_download_artifact(self):
"""Tests downloading artifacts from the model catalog."""
# Artifact size greater than 2GB
Expand Down
67 changes: 66 additions & 1 deletion tests/unitary/default_setup/model/test_oci_datascience_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ExportModelArtifactDetails,
ImportModelArtifactDetails,
Model,
ModelProvenance,
ModelProvenance, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails,
)
from oci.exceptions import ServiceError
from oci.response import Response
Expand Down Expand Up @@ -136,6 +136,12 @@ def setup_class(cls):
headers={"opc-work-request-id": "work_request_id"},
request=None,
)
cls.mock_register_model_artifact_reference_response = Response(
data=None,
status=None,
headers={"opc-work-request-id": "work_request_id"},
request=None,
)

def setup_method(self):
self.mock_model = OCIDataScienceModel(**OCI_MODEL_PAYLOAD)
Expand Down Expand Up @@ -173,6 +179,9 @@ def mock_client(self):
mock_client.export_model_artifact = MagicMock(
return_value=self.mock_export_artifact_response
)
mock_client.register_model_artifact_reference = MagicMock(
return_value=self.mock_register_model_artifact_reference_response
)
return mock_client

def test_create_fail(self):
Expand Down Expand Up @@ -463,6 +472,62 @@ def test_export_model_artifact(
progress_bar_description="Exporting model artifacts."
)

@patch(
    "ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request"
)
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__")
def test_register_model_artifact_reference(
    self,
    mock_data_science_work_request,
    mock_wait_work_request,
    mock_client,
):
    """Tests register model artifact reference for a model in model catalog.

    Covers four URI shapes: a prefix with a trailing slash, a nested prefix
    without one, a bare bucket with a trailing slash, and a bare bucket
    without one. The last two are expected to produce no prefix.
    """
    # (bucket URI, expected namespace, expected bucket name, expected prefix)
    cases = [
        ("oci://bucket1@namespace1/prefix1/", 'namespace1', 'bucket1', 'prefix1'),
        ("oci://bucket2@namespace2/prefix2/subPrefix2", 'namespace2', 'bucket2', 'prefix2/subPrefix2'),
        ("oci://bucket3@namespace3/", 'namespace3', 'bucket3', None),
        ("oci://bucket4@namespace4", 'namespace4', 'bucket4', None),
    ]

    expected_references = []
    for _uri, namespace, bucket_name, prefix in cases:
        reference = OSSModelArtifactReferenceDetails()
        reference.namespace = namespace
        reference.bucket_name = bucket_name
        reference.prefix = prefix
        expected_references.append(reference)

    expected_details = RegisterModelArtifactReferenceDetails()
    expected_details.model_artifact_references = expected_references

    # __init__ is patched and must return None for the instance to be created.
    mock_data_science_work_request.return_value = None
    with patch.object(OCIDataScienceModel, "client", mock_client):
        self.mock_model.register_model_artifact_reference(
            bucket_uri_list=[uri for uri, *_rest in cases]
        )
        mock_client.register_model_artifact_reference.assert_called_with(
            model_id=self.mock_model.id,
            register_model_artifact_reference_details=expected_details
        )
        mock_data_science_work_request.assert_called_with("work_request_id")
        mock_wait_work_request.assert_called_with(
            progress_bar_description="Registering model artifact references."
        )

def test_is_model_by_reference(self):
"""Test to check if model is created by reference using custom metadata information"""

Expand Down

0 comments on commit 298ccc6

Please sign in to comment.