Skip to content

Commit

Permalink
change: improve logging and exception messages (#4877)
Browse files Browse the repository at this point in the history
* chore: improve logging and exception messages

* fix: flake8

* chore: address PR comments
  • Loading branch information
evakravi authored Sep 30, 2024
1 parent 92e493c commit 66d5fdf
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 21 deletions.
3 changes: 2 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def _model_id_retrieval_function(
raise KeyError(error_msg)

error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
error_msg += "Specify a different model ID or try a different AWS Region. "
error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "

other_model_id_version = None
if model_type == JumpStartModelType.OPEN_WEIGHTS:
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
)

INVALID_MODEL_ID_ERROR_MSG = (
"Invalid model ID: '{model_id}'. Please visit "
f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. "
"Invalid model ID: '{model_id}'. Specify a different model ID or try a different AWS Region. "
f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "
"The module `sagemaker.jumpstart.notebook_utils` contains utilities for "
"fetching model IDs. We recommend upgrading to the latest version of sagemaker "
"to get access to the most models."
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
)
from sagemaker.session import Session
from sagemaker.config import load_sagemaker_config
from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour
from sagemaker.utils import (
resolve_value_from_config,
TagsDict,
get_instance_rate_per_hour,
get_domain_for_region,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.user_agent import get_user_agent_extra_suffix

Expand Down Expand Up @@ -553,7 +558,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
return (
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}"
f"{get_domain_for_region(region)}"
f"/{model_specs.hosting_eula_key} for terms of use."
)

Expand Down
18 changes: 18 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.entities import PipelineVariable

ALTERNATE_DOMAINS = {
"cn-north-1": "amazonaws.com.cn",
"cn-northwest-1": "amazonaws.com.cn",
"us-iso-east-1": "c2s.ic.gov",
"us-isob-east-1": "sc2s.sgov.gov",
"us-isof-south-1": "csp.hci.ic.gov",
"us-isof-east-1": "csp.hci.ic.gov",
}

ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
MODEL_PACKAGE_ARN_PATTERN = (
r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)"
Expand Down Expand Up @@ -1905,3 +1914,12 @@ def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
if len(updated_tags) == 1:
return updated_tags[0]
return updated_tags


def get_domain_for_region(region: str) -> str:
"""Returns the domain for the given region.
Args:
region (str): AWS region name.
"""
return ALTERNATE_DOMAINS.get(region, "amazonaws.com")
10 changes: 2 additions & 8 deletions tests/unit/sagemaker/image_uris/expected_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

ALTERNATE_DOMAINS = {
"cn-north-1": "amazonaws.com.cn",
"cn-northwest-1": "amazonaws.com.cn",
"us-iso-east-1": "c2s.ic.gov",
"us-isob-east-1": "sc2s.sgov.gov",
"us-isof-south-1": "csp.hci.ic.gov",
"us-isof-east-1": "csp.hci.ic.gov",
}
from sagemaker.utils import ALTERNATE_DOMAINS

DOMAIN = "amazonaws.com"
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}"
MONITOR_URI_FORMAT = "{}.dkr.ecr.{}.{}/sagemaker-model-monitor-analyzer"
Expand Down
22 changes: 14 additions & 8 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,26 +205,31 @@ def test_jumpstart_cache_get_header():
)
assert (
"Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with "
"version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-"
"version '3.*'. Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Consider using model ID "
"'pytorch-ic-imagenet-inception-v3-"
"classification-4' with version '2.0.0'."
) in str(e.value)

with pytest.raises(KeyError) as e:
cache.get_header(model_id="pytorch-ic-", semantic_version_str="*")
assert (
"Unable to find model manifest for 'pytorch-ic-' with version '*'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?"
) in str(e.value)

with pytest.raises(KeyError) as e:
cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*")
assert (
"Unable to find model manifest for 'tensorflow-ic-' with version '*'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. For a list "
"of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'tensorflow-ic-imagenet-inception-"
"v3-classification-4'?"
) in str(e.value)
Expand All @@ -237,8 +242,9 @@ def test_jumpstart_cache_get_header():
)
assert (
"Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'ai21-summarization'?"
) in str(e.value)

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,3 +2150,21 @@ def test_has_instance_rate_stat(stats, expected):
def test_deployment_config_response_data(data, expected):
out = utils.deployment_config_response_data(data)
assert out == expected


class TestGetEulaMessage(TestCase):
mock_model_specs = Mock(model_id="some-model-id", hosting_eula_key="some-eula-key")

def test_get_domain_for_region(self):
self.assertEqual(
utils.get_eula_message(self.mock_model_specs, "us-west-2"),
"Model 'some-model-id' requires accepting end-user license agreement (EULA). See"
" https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/some-eula-key "
"for terms of use.",
)
self.assertEqual(
utils.get_eula_message(self.mock_model_specs, "cn-north-1"),
"Model 'some-model-id' requires accepting end-user license agreement (EULA). See"
" https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key "
"for terms of use.",
)
13 changes: 13 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
camel_case_to_pascal_case,
deep_override_dict,
flatten_dict,
get_domain_for_region,
get_instance_type_family,
retry_with_backoff,
check_and_get_run_experiment_config,
Expand Down Expand Up @@ -2231,3 +2232,15 @@ def test_remove_non_existent_tag(self):
def test_remove_only_tag(self):
original_tags = [{"Key": "Tag1", "Value": "Value1"}]
self.assertIsNone(remove_tag_with_key("Tag1", original_tags))


class TestGetDomainForRegion(TestCase):
def test_get_domain_for_region(self):
self.assertEqual(get_domain_for_region("us-west-2"), "amazonaws.com")
self.assertEqual(get_domain_for_region("eu-west-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("ap-northeast-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("us-gov-west-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("cn-northwest-1"), "amazonaws.com.cn")
self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")

0 comments on commit 66d5fdf

Please sign in to comment.