From 7042af56ce67abfa4f5cc06d340dd8d0a264ae69 Mon Sep 17 00:00:00 2001 From: targarg Date: Tue, 12 Nov 2024 19:59:47 +0530 Subject: [PATCH 1/2] ODSC-64654 register model artifact reference ODSC-64654-register-model-artifact-reference --- ads/model/datascience_model.py | 20 ++++++ ads/model/service/oci_datascience_model.py | 50 +++++++++++++- .../model_catalog/model_catalog.rst | 43 +++++++++++- .../model/test_datascience_model.py | 18 +++++ .../model/test_oci_datascience_model.py | 67 ++++++++++++++++++- 5 files changed, 195 insertions(+), 3 deletions(-) diff --git a/ads/model/datascience_model.py b/ads/model/datascience_model.py index 4a2f209c4..de2b2e8f8 100644 --- a/ads/model/datascience_model.py +++ b/ads/model/datascience_model.py @@ -1405,6 +1405,26 @@ def restore_model( restore_model_for_hours_specified=restore_model_for_hours_specified, ) + def register_model_artifact_reference(self,bucket_uri_list: List[str]) -> None: + """ + Registers model artifact references against a model. + Can be used for any model for which model-artifact doesn't exist yet. Requires to provide List of Object + Storage buckets_uri(s) which contain the artifacts. + + Parameters + ---------- + bucket_uri_list: List[str] + The list of OCI Object Storage URIs where model artifacts are present. + Example: [`oci://@/prefix/`, `oci://@/prefix/`]. + + Returns + ------- + None + """ + self.dsc_model.register_model_artifact_reference( + bucket_uri_list=bucket_uri_list + ) + def download_artifact( self, target_dir: str, diff --git a/ads/model/service/oci_datascience_model.py b/ads/model/service/oci_datascience_model.py index 44ba091a6..8d2cc61ad 100644 --- a/ads/model/service/oci_datascience_model.py +++ b/ads/model/service/oci_datascience_model.py @@ -26,7 +26,7 @@ ExportModelArtifactDetails, ImportModelArtifactDetails, UpdateModelDetails, - WorkRequest, + WorkRequest, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, ModelArtifactReferenceDetails, ) from oci.exceptions import ServiceError @@ -449,6 +449,54 @@ def export_model_artifact(self, bucket_uri: str, region: str = None): progress_bar_description="Exporting model artifacts." ) + @check_for_model_id( + msg="Model needs to be saved to the Model Catalog before the artifact can be registered against it." + ) + def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None: + """ + Registers model artifact references against a model. + Can be used for any model for which model-artifact doesn't exist yet. Requires to provide List of Object + Storage buckets_uri(s) which contain the artifacts. + + Parameters + ---------- + bucket_uri_list: List[str] + The list of OCI Object Storage URIs where model artifacts are present. + Example: [`oci://@/prefix/`, `oci://@/prefix/`]. + + Returns + ------- + None + """ + model_artifact_reference_details_list = [] + for bucket_uri in bucket_uri_list: + bucket_details = ObjectStorageDetails.from_path(bucket_uri) + model_artifact_reference_details = OSSModelArtifactReferenceDetails() + model_artifact_reference_details.namespace = bucket_details.namespace + model_artifact_reference_details.bucket_name = bucket_details.bucket + if bucket_details.filepath is not None and bucket_details.filepath != "": + model_artifact_reference_details.prefix = bucket_details.filepath.strip('/') + model_artifact_reference_details_list.append(model_artifact_reference_details) + + register_model_artifact_reference_details = RegisterModelArtifactReferenceDetails() + register_model_artifact_reference_details.model_artifact_references = model_artifact_reference_details_list + + work_request_id = self.client.register_model_artifact_reference( + model_id=self.id, + register_model_artifact_reference_details=register_model_artifact_reference_details + ).headers["opc-work-request-id"] + + # Show progress of model artifact references being registered + try : + DataScienceWorkRequest(work_request_id).wait_work_request( + progress_bar_description="Registering model artifact references." + ) + logger.info("Artifact references registered successfully.") + except Exception as ex: + logger.error(f"WorkRequest: `{work_request_id}` failed. Fetching Work Request Error Logs.") + get_work_request_errors_response = self.client.list_work_request_errors(work_request_id) + logger.error(get_work_request_errors_response.data) + @check_for_model_id( msg="Model needs to be saved to the Model Catalog before it can be updated." ) diff --git a/docs/source/user_guide/model_catalog/model_catalog.rst b/docs/source/user_guide/model_catalog/model_catalog.rst index 4a8732915..a821ecb19 100644 --- a/docs/source/user_guide/model_catalog/model_catalog.rst +++ b/docs/source/user_guide/model_catalog/model_catalog.rst @@ -1553,4 +1553,45 @@ In the next example, the model that was stored in the model catalog as part of t Restore Archived Model ********************** -The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours. \ No newline at end of file +The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours. + +Register Model Artifact Reference +********************** + +The ``.register_model_artifact_reference()`` method of Model catalog registers the references of your OCI Object Storage buckets where the artifact files are present against the model. + +By using this API, you can avoid the need to upload or export large model artifacts, and can simply give the references of the OCI Object Storage locations where your artifacts are present. The OCI Data Science will directly read artifact files from those locations when you create a deployment of the model. + +The input to this method is a List of bucket_uri(s). The URI syntax for the bucket_uri is: + +oci://@// + +Example - + +.. code-block:: python3 + + model.register_model_artifact_reference( + bucket_uri_list = ["oci://@//"] + ) + +Important Points: + +1. The buckets provided should be of same region and have versioning enabled on them. + +2. The is optional. If your files that you want to use for this model are within a path in the bucket, then path can be specified in the bucket_uri, else it can be skipped like below: + + oci://@/ + +3. The location specified by bucket_uri should have at-least one object within it. + +4. Make sure that the buckets provided has following IAM policy configured to allow the Data Science service to read artifact files from those Object Storage buckets in your tenancy. An administrator must configure these policies in `IAM `_ in the Console. + + .. parsed-literal:: + + allow any-user to read object-family in compartment where ALL {target.bucket.name= '', request.principal.type = /\*datasciencemodel\*/} + + If you want, you can have a more granular policy by having an additional filter on project_id like below, which will then give access to the bucket only to models present in the data science project specified in the filter. + + .. parsed-literal:: + + allow any-user to read object-family in compartment where ALL {target.bucket.name= '', request.principal.type = /\*datasciencemodel\*/, request.principal.project_id = ''} \ No newline at end of file diff --git a/tests/unitary/default_setup/model/test_datascience_model.py b/tests/unitary/default_setup/model/test_datascience_model.py index f331576db..8fe07c053 100644 --- a/tests/unitary/default_setup/model/test_datascience_model.py +++ b/tests/unitary/default_setup/model/test_datascience_model.py @@ -817,6 +817,24 @@ def test_upload_artifact(self): ) mock_upload.assert_called() + @patch.object(OCIDataScienceModel, 'register_model_artifact_reference') + def test_register_model_artifact_reference(self, mock_register_model_artifact_reference): + + # Sample input for the test + bucket_uri_list = [ + "oci://bucket1@namespace1/prefix1/", + "oci://bucket2@namespace2/prefix2/" + ] + + # Call the function with the test data + self.mock_dsc_model.register_model_artifact_reference(bucket_uri_list=bucket_uri_list) + + # Assert that the mocked `register_model_artifact_reference` method was called once + # and with the expected arguments + mock_register_model_artifact_reference.assert_called_once_with( + bucket_uri_list=bucket_uri_list + ) + def test_download_artifact(self): """Tests downloading artifacts from the model catalog.""" # Artifact size greater than 2GB diff --git a/tests/unitary/default_setup/model/test_oci_datascience_model.py b/tests/unitary/default_setup/model/test_oci_datascience_model.py index 7a44019e1..24f6e5431 100644 --- a/tests/unitary/default_setup/model/test_oci_datascience_model.py +++ b/tests/unitary/default_setup/model/test_oci_datascience_model.py @@ -12,7 +12,7 @@ ExportModelArtifactDetails, ImportModelArtifactDetails, Model, - ModelProvenance, + ModelProvenance, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, ) from oci.exceptions import ServiceError from oci.response import Response @@ -136,6 +136,12 @@ def setup_class(cls): headers={"opc-work-request-id": "work_request_id"}, request=None, ) + cls.mock_register_model_artifact_reference_response = Response( + data=None, + status=None, + headers={"opc-work-request-id": "work_request_id"}, + request=None, + ) def setup_method(self): self.mock_model = OCIDataScienceModel(**OCI_MODEL_PAYLOAD) @@ -173,6 +179,9 @@ def mock_client(self): mock_client.export_model_artifact = MagicMock( return_value=self.mock_export_artifact_response ) + mock_client.register_model_artifact_reference = MagicMock( + return_value=self.mock_register_model_artifact_reference_response + ) return mock_client def test_create_fail(self): @@ -463,6 +472,62 @@ def test_export_model_artifact( progress_bar_description="Exporting model artifacts." ) + @patch( + "ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request" + ) + @patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__") + def test_register_model_artifact_reference( + self, + mock_data_science_work_request, + mock_wait_work_request, + mock_client, + ): + """Tests register model artifact reference for a model in model catalog.""" + test_bucket_uri_1 = "oci://bucket1@namespace1/prefix1/" + test_bucket_uri_2 = "oci://bucket2@namespace2/prefix2/subPrefix2" + test_bucket_uri_3 = "oci://bucket3@namespace3/" + test_bucket_uri_4 = "oci://bucket4@namespace4" + + model_artifact_reference_details_1 = OSSModelArtifactReferenceDetails() + model_artifact_reference_details_1.namespace = 'namespace1' + model_artifact_reference_details_1.bucket_name = 'bucket1' + model_artifact_reference_details_1.prefix = 'prefix1' + + model_artifact_reference_details_2 = OSSModelArtifactReferenceDetails() + model_artifact_reference_details_2.namespace = 'namespace2' + model_artifact_reference_details_2.bucket_name = 'bucket2' + model_artifact_reference_details_2.prefix = 'prefix2/subPrefix2' + + model_artifact_reference_details_3 = OSSModelArtifactReferenceDetails() + model_artifact_reference_details_3.namespace = 'namespace3' + model_artifact_reference_details_3.bucket_name = 'bucket3' + model_artifact_reference_details_3.prefix = None + + model_artifact_reference_details_4 = OSSModelArtifactReferenceDetails() + model_artifact_reference_details_4.namespace = 'namespace4' + model_artifact_reference_details_4.bucket_name = 'bucket4' + model_artifact_reference_details_4.prefix = None + + model_artifact_reference_details_list = [model_artifact_reference_details_1, model_artifact_reference_details_2, + model_artifact_reference_details_3, model_artifact_reference_details_4] + + register_model_artifact_reference_details = RegisterModelArtifactReferenceDetails() + register_model_artifact_reference_details.model_artifact_references = model_artifact_reference_details_list + + mock_data_science_work_request.return_value = None + with patch.object(OCIDataScienceModel, "client", mock_client): + self.mock_model.register_model_artifact_reference( + bucket_uri_list=[test_bucket_uri_1, test_bucket_uri_2, test_bucket_uri_3, test_bucket_uri_4] + ) + mock_client.register_model_artifact_reference.assert_called_with( + model_id=self.mock_model.id, + register_model_artifact_reference_details=register_model_artifact_reference_details + ) + mock_data_science_work_request.assert_called_with("work_request_id") + mock_wait_work_request.assert_called_with( + progress_bar_description="Registering model artifact references." + ) + def test_is_model_by_reference(self): """Test to check if model is created by reference using custom metadata information""" From 9223d68ba3de83c8a0cb7bc78e20739b372aa88f Mon Sep 17 00:00:00 2001 From: targarg Date: Wed, 27 Nov 2024 18:35:58 +0530 Subject: [PATCH 2/2] exiting register_artifact function with exception on failure --- ads/model/service/oci_datascience_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ads/model/service/oci_datascience_model.py b/ads/model/service/oci_datascience_model.py index 8d2cc61ad..4a8757f8c 100644 --- a/ads/model/service/oci_datascience_model.py +++ b/ads/model/service/oci_datascience_model.py @@ -496,6 +496,7 @@ def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None: logger.error(f"WorkRequest: `{work_request_id}` failed. Fetching Work Request Error Logs.") get_work_request_errors_response = self.client.list_work_request_errors(work_request_id) logger.error(get_work_request_errors_response.data) + raise Exception(get_work_request_errors_response.data) @check_for_model_id( msg="Model needs to be saved to the Model Catalog before it can be updated."