Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
pintaoz-aws committed Sep 13, 2024
1 parent 06dc458 commit 3519265
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 35 deletions.
3 changes: 2 additions & 1 deletion integ/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def test_training_and_inference(self):
content_type="text/csv",
accept="application/csv",
)
assert invoke_result.body.payload_part

assert invoke_result["Body"]["PayloadPart"]

def test_intelligent_defaults(self):
os.environ["SAGEMAKER_CORE_ADMIN_CONFIG_OVERRIDE"] = (
Expand Down
7 changes: 3 additions & 4 deletions src/sagemaker_core/main/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8389,7 +8389,7 @@ def invoke_with_response_stream(
inference_component_name: Optional[str] = Unassigned(),
session: Optional[Session] = None,
region: Optional[str] = None,
) -> Optional[InvokeEndpointWithResponseStreamOutput]:
) -> Optional[object]:
"""
Invokes a model at the specified endpoint to return the inference response as a stream.

Expand All @@ -8406,7 +8406,7 @@ def invoke_with_response_stream(
region: Region name.

Returns:
InvokeEndpointWithResponseStreamOutput
object

Raises:
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
Expand Down Expand Up @@ -8449,8 +8449,7 @@ def invoke_with_response_stream(
response = client.invoke_endpoint_with_response_stream(**operation_input_args)
logger.debug(f"Response: {response}")

transformed_response = transform(response, "InvokeEndpointWithResponseStreamOutput")
return InvokeEndpointWithResponseStreamOutput(**transformed_response)
return response


class EndpointConfig(Base):
Expand Down
18 changes: 0 additions & 18 deletions src/sagemaker_core/main/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,6 @@ class ResponseStream(Base):
internal_stream_failure: Optional[InternalStreamFailure] = Unassigned()


class InvokeEndpointWithResponseStreamOutput(Base):
"""
InvokeEndpointWithResponseStreamOutput

Attributes
----------------------
body
content_type: The MIME type of the inference returned from the model container.
invoked_production_variant: Identifies the production variant that was invoked.
custom_attributes: Provides additional information in the response about the inference returned by a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to return an ID received in the CustomAttributes header of a request or other metadata that a service endpoint was programmed to produce. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). If the customer wants the custom attribute returned, the model must set the custom attribute to be included on the way back. The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK.
"""

body: ResponseStream
content_type: Optional[str] = Unassigned()
invoked_production_variant: Optional[str] = Unassigned()
custom_attributes: Optional[str] = Unassigned()


class ModelError(Base):
"""
ModelError
Expand Down
9 changes: 3 additions & 6 deletions src/sagemaker_core/main/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,7 @@ def _serialize_dict(value: Dict) -> dict:
"""
serialized_dict = {}
for k, v in value.items():
serialize_result = serialize(v)
if serialize_result is not None:
if (serialize_result := serialize(v)) is not None:
serialized_dict.update({k: serialize_result})
return serialized_dict

Expand All @@ -518,8 +517,7 @@ def _serialize_list(value: List) -> list:
"""
serialized_list = []
for v in value:
serialize_result = serialize(v)
if serialize_result is not None:
if (serialize_result := serialize(v)) is not None:
serialized_list.append(serialize_result)
return serialized_list

Expand All @@ -536,8 +534,7 @@ def _serialize_shape(value: Any) -> dict:
"""
serialized_dict = {}
for k, v in vars(value).items():
serialize_result = serialize(v)
if serialize_result is not None:
if (serialize_result := serialize(v)) is not None:
key = snake_to_pascal(k) if is_snake_case(k) else k
serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
return serialized_dict
2 changes: 1 addition & 1 deletion src/sagemaker_core/tools/additional_operations.json
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
"operation_name": "InvokeEndpointWithResponseStream",
"resource_name": "Endpoint",
"method_name": "invoke_with_response_stream",
"return_type": "InvokeEndpointWithResponseStreamOutput",
"return_type": "object",
"method_type": "object",
"service_name": "sagemaker-runtime"
}
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker_core/tools/resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
REFRESH_METHOD_TEMPLATE,
RESOURCE_BASE_CLASS_TEMPLATE,
RETURN_ITERATOR_TEMPLATE,
RETURN_WITHOUT_DESERIALIZATION_TEMPLATE,
SERIALIZE_INPUT_TEMPLATE,
STOP_METHOD_TEMPLATE,
DELETE_METHOD_TEMPLATE,
Expand Down Expand Up @@ -1373,6 +1374,11 @@ def generate_method(self, method: Method, resource_attributes: list):
return_type = f"Optional[{method.return_type}]"
deserialize_response = DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE
return_string = f"Returns:\n" f" {method.return_type}\n"
elif method.return_type == "object":
# if the return type is object, return the response without deserialization
return_type = f"Optional[{method.return_type}]"
deserialize_response = RETURN_WITHOUT_DESERIALIZATION_TEMPLATE
return_string = f"Returns:\n" f" {method.return_type}\n"
else:
if method.return_type == "cls":
return_type = f'Optional["{method.resource_name}"]'
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker_core/tools/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ def {method_name}(
DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE = """
return list(response.values())[0]"""

RETURN_WITHOUT_DESERIALIZATION_TEMPLATE = """
return response"""

RETURN_ITERATOR_TEMPLATE = """
return ResourceIterator(
{resource_iterator_args}
Expand Down
2 changes: 2 additions & 0 deletions tst/generated/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def test_resources(self, session, mock_transform):
operation_info["return_type"]
]
}
elif operation_info["return_type"] == "object":
return_value = {"return_value": None}
else:
return_cls = self.SHAPE_CLASSES_BY_SHAPE_NAME[
operation_info["return_type"]
Expand Down
9 changes: 4 additions & 5 deletions tst/tools/test_resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def invoke_with_response_stream(
inference_component_name: Optional[str] = Unassigned(),
session: Optional[Session] = None,
region: Optional[str] = None,
) -> Optional[InvokeEndpointWithResponseStreamOutput]:
) -> Optional[object]:
"""
Invokes a model at the specified endpoint to return the inference response as a stream.
Expand All @@ -890,7 +890,7 @@ def invoke_with_response_stream(
region: Region name.
Returns:
InvokeEndpointWithResponseStreamOutput
object
Raises:
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
Expand Down Expand Up @@ -932,15 +932,14 @@ def invoke_with_response_stream(
response = client.invoke_endpoint_with_response_stream(**operation_input_args)
logger.debug(f"Response: {response}")
transformed_response = transform(response, 'InvokeEndpointWithResponseStreamOutput')
return InvokeEndpointWithResponseStreamOutput(**transformed_response)
return response
'''
method = Method(
**{
"operation_name": "InvokeEndpointWithResponseStream",
"resource_name": "Endpoint",
"method_name": "invoke_with_response_stream",
"return_type": "InvokeEndpointWithResponseStreamOutput",
"return_type": "object",
"method_type": "object",
"service_name": "sagemaker-runtime",
}
Expand Down

0 comments on commit 3519265

Please sign in to comment.