Skip to content

Commit

Permalink
Resolve a majority of the request parsing test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan343 committed Jul 26, 2024
1 parent 8a305de commit 56eb039
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 33 deletions.
106 changes: 81 additions & 25 deletions botocore/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def _expand_host_prefix(self, parameters, operation_model):
return None

host_prefix_expression = operation_endpoint['hostPrefix']
input_members = operation_model.input_shape.members
input_members = (
operation_model.input_shape.members
if operation_model.input_shape
else {}
)
host_labels = [
member
for member, shape in input_members.items()
Expand All @@ -203,6 +207,9 @@ def _expand_host_prefix(self, parameters, operation_model):
)
return host_prefix_expression.format(**format_kwargs)

def _is_shape_flattened(self, shape):
return shape.serialization.get('flattened')


class QuerySerializer(Serializer):
TIMESTAMP_FORMAT = 'iso8601'
Expand Down Expand Up @@ -262,10 +269,6 @@ def _serialize_type_list(self, serialized, value, shape, prefix=''):
return
if self._is_shape_flattened(shape):
list_prefix = prefix
if shape.member.serialization.get('name'):
name = self._get_serialized_name(shape.member, default_name='')
# Replace '.Original' with '.{name}'.
list_prefix = '.'.join(prefix.split('.')[:-1] + [name])
else:
list_name = shape.member.serialization.get('name', 'member')
list_prefix = f'{prefix}.{list_name}'
Expand Down Expand Up @@ -308,9 +311,6 @@ def _serialize_type_boolean(self, serialized, value, shape, prefix=''):
def _default_serialize(self, serialized, value, shape, prefix=''):
serialized[prefix] = value

def _is_shape_flattened(self, shape):
return shape.serialization.get('flattened')


class EC2Serializer(QuerySerializer):
"""EC2 specific customizations to the query protocol serializers.
Expand All @@ -324,7 +324,7 @@ class EC2Serializer(QuerySerializer):

def _get_serialized_name(self, shape, default_name):
# Returns the serialized name for the shape if it exists.
# Otherwise it will return the passed in default_name.
# Otherwise it will return the passed in capitalized default_name.
if 'queryName' in shape.serialization:
return shape.serialization['queryName']
elif 'name' in shape.serialization:
Expand Down Expand Up @@ -420,6 +420,8 @@ def _serialize_type_list(self, serialized, value, shape, key):
list_obj.append(wrapper["__current__"])

def _default_serialize(self, serialized, value, shape, key):
if value is None:
return
serialized[key] = value

def _serialize_type_timestamp(self, serialized, value, shape, key):
Expand Down Expand Up @@ -600,7 +602,19 @@ def _partition_parameters(
location = member.serialization.get('location')
key_name = member.serialization.get('name', param_name)
if location == 'uri':
partitioned['uri_path_kwargs'][key_name] = param_value
if isinstance(param_value, bool):
bool_str = str(param_value).lower()
partitioned['uri_path_kwargs'][key_name] = bool_str
elif member.type_name == 'timestamp':
timestamp_format = member.serialization.get(
'timestampFormat', self.QUERY_STRING_TIMESTAMP_FORMAT
)
timestamp = self._convert_timestamp_to_str(
param_value, timestamp_format
)
partitioned['uri_path_kwargs'][key_name] = timestamp
else:
partitioned['uri_path_kwargs'][key_name] = param_value
elif location == 'querystring':
if isinstance(param_value, dict):
partitioned['query_string_kwargs'].update(param_value)
Expand All @@ -619,7 +633,7 @@ def _partition_parameters(
partitioned['query_string_kwargs'][key_name] = param_value
elif location == 'header':
shape = shape_members[param_name]
if not param_value and shape.type_name == 'list':
if not param_value and shape.type_name in ['list', 'string']:
# Empty lists should not be set on the headers
return
value = self._convert_header_value(shape, param_value)
Expand Down Expand Up @@ -687,7 +701,13 @@ def _serialize_content_type(self, serialized, shape, shape_members):
"""Set Content-Type to application/json for all structured bodies."""
payload = shape.serialization.get('payload')
if self._has_streaming_payload(payload, shape_members):
# Don't apply content-type to streaming bodies
# TODO: Confirm if we should apply content-type to streaming bodies.
if shape_members[payload].type_name == 'string':
serialized['headers']['Content-Type'] = 'text/plain'
elif shape_members[payload].type_name == 'blob':
serialized['headers']['Content-Type'] = (
'application/octet-stream'
)
return

has_body = serialized['body'] != b''
Expand Down Expand Up @@ -722,12 +742,7 @@ def _serialize(self, shape, params, xmlnode, name):
def _serialize_type_structure(self, xmlnode, params, shape, name):
structure_node = ElementTree.SubElement(xmlnode, name)

if 'xmlNamespace' in shape.serialization:
namespace_metadata = shape.serialization['xmlNamespace']
attribute_name = 'xmlns'
if namespace_metadata.get('prefix'):
attribute_name += f":{namespace_metadata['prefix']}"
structure_node.attrib[attribute_name] = namespace_metadata['uri']
self._add_xml_namespace(shape, structure_node)
for key, value in params.items():
member_shape = shape.members[key]
member_name = member_shape.serialization.get('name', key)
Expand All @@ -753,6 +768,7 @@ def _serialize_type_list(self, xmlnode, params, shape, name):
else:
element_name = member_shape.serialization.get('name', 'member')
list_node = ElementTree.SubElement(xmlnode, name)
self._add_xml_namespace(shape, list_node)
for item in params:
self._serialize(member_shape, item, list_node, element_name)

Expand All @@ -765,22 +781,29 @@ def _serialize_type_map(self, xmlnode, params, shape, name):
# <value>val1</value>
# </entry>
# </MyMap>
node = ElementTree.SubElement(xmlnode, name)
# TODO: handle flattened maps.
if not self._is_shape_flattened(shape):
node = ElementTree.SubElement(xmlnode, name)
self._add_xml_namespace(shape, node)

for key, value in params.items():
entry_node = ElementTree.SubElement(node, 'entry')
sub_node = (
ElementTree.SubElement(xmlnode, name)
if self._is_shape_flattened(shape)
else ElementTree.SubElement(node, 'entry')
)
key_name = self._get_serialized_name(shape.key, default_name='key')
val_name = self._get_serialized_name(
shape.value, default_name='value'
)
self._serialize(shape.key, key, entry_node, key_name)
self._serialize(shape.value, value, entry_node, val_name)
self._serialize(shape.key, key, sub_node, key_name)
self._serialize(shape.value, value, sub_node, val_name)

def _serialize_type_boolean(self, xmlnode, params, shape, name):
# For scalar types, the 'params' attr is actually just a scalar
# value representing the data we need to serialize as a boolean.
# It will either be 'true' or 'false'
node = ElementTree.SubElement(xmlnode, name)
self._add_xml_namespace(shape, node)
if params:
str_value = 'true'
else:
Expand All @@ -789,18 +812,51 @@ def _serialize_type_boolean(self, xmlnode, params, shape, name):

def _serialize_type_blob(self, xmlnode, params, shape, name):
node = ElementTree.SubElement(xmlnode, name)
self._add_xml_namespace(shape, node)
node.text = self._get_base64(params)
node.text = self._get_base64(params)

def _serialize_type_timestamp(self, xmlnode, params, shape, name):
node = ElementTree.SubElement(xmlnode, name)
node.text = self._convert_timestamp_to_str(
params, shape.serialization.get('timestampFormat')
self._add_xml_namespace(shape, node)
node.text = str(
self._convert_timestamp_to_str(
params, shape.serialization.get('timestampFormat')
)
)

def _default_serialize(self, xmlnode, params, shape, name):
node = ElementTree.SubElement(xmlnode, name)
self._add_xml_namespace(shape, node)
node.text = str(params)

def _serialize_content_type(self, serialized, shape, shape_members):
"""Set Content-Type to application/json for all structured bodies."""
payload = shape.serialization.get('payload')
if self._has_streaming_payload(payload, shape_members):
# TODO: Confirm if we should apply content-type to streaming bodies.
if shape_members[payload].type_name == 'string':
serialized['headers']['Content-Type'] = 'text/plain'
elif shape_members[payload].type_name == 'blob':
serialized['headers']['Content-Type'] = (
'application/octet-stream'
)
return
serialized['headers']['Content-Type'] = 'application/xml'

def _add_xml_namespace(self, shape, structure_node):
if 'xmlNamespace' in shape.serialization:
namespace_metadata = shape.serialization['xmlNamespace']
attribute_name = 'xmlns'
if isinstance(namespace_metadata, dict):
if namespace_metadata.get('prefix'):
attribute_name += f":{namespace_metadata['prefix']}"
structure_node.attrib[attribute_name] = namespace_metadata[
'uri'
]
elif isinstance(namespace_metadata, str):
structure_node.attrib[attribute_name] = namespace_metadata


SERIALIZERS = {
'ec2': EC2Serializer,
Expand Down
41 changes: 33 additions & 8 deletions tests/unit/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
"""

import copy
import math
import os
import xml.etree.ElementTree as ET
from base64 import b64decode
from enum import Enum

Expand Down Expand Up @@ -97,7 +99,9 @@
'rest-json': RestJSONParser,
'rest-xml': RestXMLParser,
}
PROTOCOL_TEST_BLACKLIST = ['Idempotency token auto fill']
PROTOCOL_TEST_BLACKLIST = [
"Test cases for QueryIdempotencyTokenAutoFill operation",
]


class TestType(Enum):
Expand Down Expand Up @@ -129,7 +133,7 @@ def _compliance_tests(test_type=None):
def test_input_compliance(json_description, case, basename):
service_description = copy.deepcopy(json_description)
service_description['operations'] = {
case.get('name', 'OperationName'): case,
case.get('given', {}).get('name', 'OperationName'): case,
}
model = ServiceModel(service_description)
protocol_type = model.metadata['protocol']
Expand All @@ -145,7 +149,7 @@ def test_input_compliance(json_description, case, basename):
client_endpoint = service_description.get('clientEndpoint')
try:
_assert_request_body_is_bytes(request['body'])
_assert_requests_equal(request, case['serialized'])
_assert_requests_equal(request, case['serialized'], protocol_type)
_assert_endpoints_equal(request, case['serialized'], client_endpoint)
except AssertionError as e:
_input_failure_message(protocol_type, case, request, e)
Expand Down Expand Up @@ -423,11 +427,32 @@ def _serialize_request_description(request_dict):
request_dict['url_path'] += f'&{encoded}'


def _assert_requests_equal(actual, expected):
assert_equal(
actual['body'], expected.get('body', '').encode('utf-8'), 'Body value'
)
def _assert_requests_equal(actual, expected, protocol_type):
expected_body = expected.get('body', '').encode('utf-8')
actual_body = actual['body']
# The expected bodies in our consumed protocol tests have extra
# whitespace and newlines that need to handled. We need to normalize
# the expected and actual response bodies before evaluating equivalence.
try:
if protocol_type in ['json', 'rest-json']:
assert_equal(
json.loads(actual_body),
json.loads(expected_body),
'Body value',
)
elif protocol_type in ['rest-xml']:
tree1 = ET.canonicalize(actual_body, strip_text=True)
tree2 = ET.canonicalize(expected_body, strip_text=True)
assert_equal(tree1, tree2, 'Body value')
else:
assert_equal(actual_body, expected_body, 'Body value')
except (json.JSONDecodeError, ET.ParseError):
assert_equal(actual_body, expected_body, 'Body value')

actual_headers = HeadersDict(actual['headers'])
if protocol_type in ['query', 'ec2']:
if expected.get('headers', {}).get('Content-Type'):
expected['headers']['Content-Type'] += '; charset=utf-8'
expected_headers = HeadersDict(expected.get('headers', {}))
excluded_headers = expected.get('forbidHeaders', [])
_assert_expected_headers_in_request(
Expand Down Expand Up @@ -460,7 +485,7 @@ def _walk_files():


def _load_cases(full_path):
# During developement, you can set the BOTOCORE_TEST_ID
# During development, you can set the BOTOCORE_TEST_ID
# to run a specific test suite or even a specific test case.
# The format is BOTOCORE_TEST_ID=suite_id:test_id or
# BOTOCORE_TEST_ID=suite_id
Expand Down

0 comments on commit 56eb039

Please sign in to comment.