Fix: Returning ModelPackage object on register of PipelineModel
Keshav Chandak committed Oct 1, 2024
1 parent 66d5fdf commit 4ab4e9a
Showing 2 changed files with 46 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/sagemaker/pipeline.py
@@ -26,6 +26,7 @@
 )
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker.metadata_properties import MetadataProperties
+from sagemaker.model import ModelPackage
 from sagemaker.model_card import (
     ModelCard,
     ModelPackageModelCard,
@@ -470,7 +471,17 @@ def register(
             model_card=model_card,
         )

-        self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
+        model_package = self.sagemaker_session.create_model_package_from_containers(
+            **model_pkg_args
+        )
+
+        if "ModelPackageArn" in model_package:
+            return ModelPackage(
+                role=self.role,
+                model_package_arn=model_package.get("ModelPackageArn"),
+                sagemaker_session=self.sagemaker_session,
+                predictor_cls=self.predictor_cls,
+            )

     def transformer(
         self,
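
With this change, PipelineModel.register() returns a sagemaker.model.ModelPackage built from the ModelPackageArn in the CreateModelPackage response, instead of returning None. The following is an illustrative sketch only, not part of the commit: the S3 URI, role name, and model package group name are placeholder assumptions, and the SparkML setup mirrors the integration test added below.

# Illustrative sketch (not part of the commit): using the ModelPackage now
# returned by PipelineModel.register(). S3 URI, role, and group name are
# placeholder assumptions.
from sagemaker import Session
from sagemaker.pipeline import PipelineModel
from sagemaker.sparkml import SparkMLModel

session = Session()

sparkml_model = SparkMLModel(
    model_data="s3://my-bucket/sparkml/mleap_model.tar.gz",  # placeholder URI
    role="SageMakerRole",  # placeholder role
    sagemaker_session=session,
)

pipeline_model = PipelineModel(
    models=[sparkml_model],
    role="SageMakerRole",  # placeholder role
    sagemaker_session=session,
)

# register() previously returned None; it now returns a
# sagemaker.model.ModelPackage whose ARN the caller can read directly.
model_package = pipeline_model.register(
    model_package_group_name="my-pipeline-model-group"  # placeholder group
)
print(model_package.model_package_arn)
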
34 changes: 34 additions & 0 deletions tests/integ/test_inference_pipeline.py
@@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
     assert "Could not find model" in str(exception.value)


+@pytest.mark.release
+def test_inference_pipeline_model_register(sagemaker_session):
+    sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
+    endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
+    sparkml_model_data = sagemaker_session.upload_data(
+        path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
+        key_prefix="integ-test-data/sparkml/model",
+    )
+
+    sparkml_model = SparkMLModel(
+        model_data=sparkml_model_data,
+        env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
+        sagemaker_session=sagemaker_session,
+    )
+
+    model = PipelineModel(
+        models=[sparkml_model],
+        role="SageMakerRole",
+        sagemaker_session=sagemaker_session,
+        name=endpoint_name,
+    )
+    model_package_group_name = unique_name_from_base("pipeline-model-package")
+    model_package = model.register(model_package_group_name=model_package_group_name)
+    assert model_package.model_package_arn is not None
+
+    sagemaker_session.sagemaker_client.delete_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    sagemaker_session.sagemaker_client.delete_model_package_group(
+        ModelPackageGroupName=model_package_group_name
+    )
+
+
 @pytest.mark.slow_test
 @pytest.mark.flaky(reruns=5, reruns_delay=2)
 def test_inference_pipeline_model_deploy_and_update_endpoint(
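
The new test asserts only that the returned ModelPackage carries an ARN before cleaning up the package and its group. As a standalone follow-up sketch (not part of the commit), that ARN can also be checked against the Model Registry through the low-level client; the ARN below is a placeholder assumption.

# Standalone sketch (not part of the commit): verifying a registered package
# by ARN via the low-level SageMaker client. The ARN is a placeholder.
import boto3

sm_client = boto3.client("sagemaker")
package_arn = (
    "arn:aws:sagemaker:us-west-2:123456789012:"
    "model-package/pipeline-model-package/1"  # placeholder ARN
)
details = sm_client.describe_model_package(ModelPackageName=package_arn)
print(details["ModelPackageStatus"], details["ModelPackageGroupName"])
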
