From 4ab4e9af6c25bf9bbf805b6785e363b8cd41cf74 Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Tue, 1 Oct 2024 05:25:50 +0530 Subject: [PATCH] Fix: Returning ModelPackage object on register of PipelineModel --- src/sagemaker/pipeline.py | 13 +++++++++- tests/integ/test_inference_pipeline.py | 34 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index b5a3cd4357..3f7aabd3e9 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -26,6 +26,7 @@ ) from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model import ModelPackage from sagemaker.model_card import ( ModelCard, ModelPackageModelCard, @@ -470,7 +471,17 @@ def register( model_card=model_card, ) - self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) + model_package = self.sagemaker_session.create_model_package_from_containers( + **model_pkg_args + ) + + if "ModelPackageArn" in model_package: + return ModelPackage( + role=self.role, + model_package_arn=model_package.get("ModelPackageArn"), + sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, + ) def transformer( self, diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 9e6b41d753..6504932a7e 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type): assert "Could not find model" in str(exception.value) +@pytest.mark.release +def test_inference_pipeline_model_register(sagemaker_session): + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") + endpoint_name = unique_name_from_base("test-inference-pipeline-deploy") + sparkml_model_data = sagemaker_session.upload_data( + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), + key_prefix="integ-test-data/sparkml/model", + ) + + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + + model = PipelineModel( + models=[sparkml_model], + role="SageMakerRole", + sagemaker_session=sagemaker_session, + name=endpoint_name, + ) + model_package_group_name = unique_name_from_base("pipeline-model-package") + model_package = model.register(model_package_group_name=model_package_group_name) + assert model_package.model_package_arn is not None + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + + @pytest.mark.slow_test @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_inference_pipeline_model_deploy_and_update_endpoint(