Skip to content

Commit

Permalink
Resolve a majority of the response parsing test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan343 committed Jul 11, 2024
1 parent 45b329a commit 895d306
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 63 deletions.
112 changes: 71 additions & 41 deletions botocore/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ def _has_unknown_tagged_union_member(self, shape, value):
)
raise ResponseParserError(error_msg % shape.name)
tag = self._get_first_key(cleaned_value)
if tag not in shape.members:
serialized_member_names = [
shape.members[member].serialization.get('name', member)
for member in shape.members
]
if tag not in serialized_member_names:
msg = (
"Received a tagged union response with member "
"unknown to client: %s. Please upgrade SDK for full "
Expand Down Expand Up @@ -459,20 +463,9 @@ def _get_error_root(self, original_root):
return original_root

def _member_key_name(self, shape, member_name):
# Return the key under which a member is looked up in the parsed node.
# This method is needed because we have to special case flattened list
# with a serialization name. If this is the case we use the
# locationName from the list's member shape as the key name for the
# surrounding structure.
if shape.type_name == 'list' and shape.serialization.get('flattened'):
list_member_serialized_name = shape.member.serialization.get(
'name'
)
if list_member_serialized_name is not None:
return list_member_serialized_name
serialized_name = shape.serialization.get('name')
if serialized_name is not None:
return serialized_name
return member_name
# Structure shapes are keyed by the member name as given; any other
# shape type uses its serialized 'name' when present, falling back to
# the member name.
if shape.type_name == 'structure':
return member_name
return shape.serialization.get('name', member_name)

def _build_name_to_xml_node(self, parent_node):
# If the parent node is actually a list. We should not be trying
Expand Down Expand Up @@ -574,6 +567,8 @@ def _do_modeled_error_parse(self, response, shape):
return self._parse_body_as_xml(response, shape, inject_metadata=False)

def _do_parse(self, response, shape):
if not response.get('body') or shape is None:
return {}
return self._parse_body_as_xml(response, shape, inject_metadata=True)

def _parse_body_as_xml(self, response, shape, inject_metadata=True):
Expand All @@ -586,14 +581,22 @@ def _parse_body_as_xml(self, response, shape, inject_metadata=True):
start = self._find_result_wrapped_shape(
shape.serialization['resultWrapper'], root
)
else:
operation_name = response['context'].get("operation_name", "")
inferred_wrapper_name = operation_name + "Result"
inferred_wrapper = self._find_result_wrapped_shape(
inferred_wrapper_name, root
)
if inferred_wrapper is not None:
start = inferred_wrapper
parsed = self._parse_shape(shape, start)
if inject_metadata:
self._inject_response_metadata(root, parsed)
return parsed

def _find_result_wrapped_shape(self, element_name, xml_root_node):
# Build the name -> node mapping for xml_root_node and look up
# element_name in it.
mapping = self._build_name_to_xml_node(xml_root_node)
# NOTE(review): indexing presumably raises KeyError for a missing
# element (assuming the mapping is a dict), whereas .get() returns None
# so callers can test for an absent wrapper.
return mapping[element_name]
return mapping.get(element_name)

def _inject_response_metadata(self, node, inject_into):
mapping = self._build_name_to_xml_node(node)
Expand Down Expand Up @@ -704,15 +707,24 @@ def _do_error_parse(self, response, shape):

code = body.get('__type', response_code and str(response_code))
if code is not None:
# code has a couple forms as well:
# * "com.aws.dynamodb.vAPI#ProvisionedThroughputExceededException"
# * "ResourceNotFoundException"
if '#' in code:
code = code.rsplit('#', 1)[1]
# The "Code" value can come from either a response
# header or a value in the JSON body.
if 'x-amzn-query-error' in headers:
code = self._do_query_compatible_error_parse(
code, headers, error
)
if 'x-amzn-errortype' in response['headers']:
code = response['headers']['x-amzn-errortype']
# error['Error']['Code'] = code
elif 'code' in body or 'Code' in body:
code = body.get('code', body.get('Code', ''))
# code has a couple forms as well:
# * "com.aws.dynamodb.vAPI#ProvisionedThroughputExceededException"
# * "ResourceNotFoundException"
if ':' in code:
code = code.split(':', 1)[0]
if '#' in code:
code = code.split('#', 1)[1]
error['Error']['Code'] = code
self._inject_response_metadata(error, response['headers'])
return error
Expand Down Expand Up @@ -743,7 +755,16 @@ def _parse_body_as_json(self, body_contents):
return {}
body = body_contents.decode(self.DEFAULT_ENCODING)
try:
original_parsed = json.loads(body)
# Function to remove null values from a JSON object.
# Only the direct None values of the dict or list passed in are dropped;
# this function does not descend into nested containers itself.
# NOTE(review): json.loads invokes object_hook only with decoded JSON
# objects (dicts), so the list branch looks unreachable when this is
# used as an object_hook -- confirm whether nulls inside arrays are
# meant to be preserved.
def remove_nulls(obj):
if isinstance(obj, dict):
return {k: v for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [v for v in obj if v is not None]
else:
return obj

original_parsed = json.loads(body, object_hook=remove_nulls)
return original_parsed
except ValueError:
# if the body cannot be parsed, include
Expand Down Expand Up @@ -994,14 +1015,28 @@ def _handle_string(self, shape, value):
parsed = value
if is_json_value_header(shape):
decoded = base64.b64decode(value).decode(self.DEFAULT_ENCODING)
parsed = json.loads(decoded)
parsed = json.dumps(json.loads(decoded))
return parsed

def _handle_list_header(self, node):
# TODO: Clean up and consider timestamps.
TOKEN_PATTERN = r'[!#$%&\'*+\-.^_`|~\w]+'
QUOTED_STRING_PATTERN = r'"(?:[^"\\]|\\.)*"'
PATTERN = fr'({QUOTED_STRING_PATTERN}|{TOKEN_PATTERN})'
matches = re.findall(PATTERN, node)
parsed_values = []
for match in matches:
if match.startswith('"') and match.endswith('"'):
parsed_values.append(match[1:-1].replace('\\"', '"'))
else:
parsed_values.append(match)
return parsed_values

def _handle_list(self, shape, node):
# Header-located lists may arrive as a single raw string rather than a
# list; such a string is split into members before delegating to the
# parent class's list handling.
location = shape.serialization.get('location')
if location == 'header' and not isinstance(node, list):
# List in headers may be a comma separated string as per RFC7230
node = [e.strip() for e in node.split(',')]
node = self._handle_list_header(node)
return super()._handle_list(shape, node)


Expand All @@ -1011,28 +1046,23 @@ class RestJSONParser(BaseRestParser, BaseJSONParser):
def _initial_body_parse(self, body_contents):
return self._parse_body_as_json(body_contents)

def _do_error_parse(self, response, shape):
    """Parse an error response, then fill in its error code.

    The parent class produces the error dict; ``_inject_error_code`` is
    then given the chance to set ``Error.Code`` from the response.

    :return: the error dict produced by the parent class.
    """
    parsed_error = super()._do_error_parse(response, shape)
    self._inject_error_code(parsed_error, response)
    return parsed_error

def _inject_error_code(self, error, response):
# The "Code" value can come from either a response
# header or a value in the JSON body.
body = self._initial_body_parse(response['body'])
if 'x-amzn-errortype' in response['headers']:
code = response['headers']['x-amzn-errortype']
# Could be:
# x-amzn-errortype: ValidationException:
code = code.split(':')[0]
error['Error']['Code'] = code
elif 'code' in body or 'Code' in body:
error['Error']['Code'] = body.get('code', body.get('Code', ''))
def _handle_boolean(self, shape, value):
# It's possible to receive a boolean as a string
if isinstance(value, str):
if value == 'true':
return True
else:
return False
return value

def _handle_integer(self, shape, value):
return int(value)

def _handle_float(self, shape, value):
return float(value)

# Long and double values reuse the integer and float handlers respectively.
_handle_long = _handle_integer
_handle_double = _handle_float


class RestXMLParser(BaseRestParser, BaseXMLResponseParser):
Expand Down
74 changes: 52 additions & 22 deletions tests/unit/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
"""
import copy
import math
import os
from base64 import b64decode
from calendar import timegm
from enum import Enum

import pytest
Expand Down Expand Up @@ -180,7 +180,7 @@ def stream(self):
)
def test_output_compliance(json_description, case, basename):
service_description = copy.deepcopy(json_description)
operation_name = case.get('name', 'OperationName')
operation_name = case.get('given', {}).get('name', 'OperationName')
service_description['operations'] = {
operation_name: case,
}
Expand All @@ -193,19 +193,26 @@ def test_output_compliance(json_description, case, basename):
)
# We load the json as utf-8, but the response parser is at the
# botocore boundary, so it expects to work with bytes.
body_bytes = case['response']['body'].encode('utf-8')
case['response']['body'] = body_bytes
# If a test case doesn't define a response body, set it to `None`.
if 'body' in case['response']:
body_bytes = case['response']['body'].encode('utf-8')
case['response']['body'] = body_bytes
else:
case['response']['body'] = None
# We need the headers to be case insensitive
headers = HeadersDict(case['response']['headers'])
case['response']['headers'] = headers
# If a test case doesn't define response headers, set it to an empty `HeadersDict`.
case['response']['headers'] = HeadersDict(
case['response'].get('headers', {})
)
# If this is an event stream fake the raw streamed response
if operation_model.has_event_stream_output:
case['response']['body'] = MockRawResponse(body_bytes)
if 'error' in case:
output_shape = operation_model.output_shape
parsed = parser.parse(case['response'], output_shape)
try:
error_shape = model.shape_for(parsed['Error']['Code'])
error_code = parsed.get("Error", {}).get("Code")
error_shape = model.shape_for_error_code(error_code)
except NoShapeFoundError:
error_shape = None
if error_shape is not None:
Expand Down Expand Up @@ -279,32 +286,55 @@ def _fixup_parsed_result(parsed):
for key in error_keys:
if key not in ['Code', 'Message']:
del parsed['Error'][key]
# 5. Special float types. In the protocol test suite, certain special float
# types are represented as strings: "Infinity", "-Infinity", and "NaN".
# However, we parse these values as actual floats types, so we need to convert
# them back to their string representation.
parsed = _convert_special_floats_to_string(parsed)
return parsed


# Recursively rebuild dicts and lists, decoding any bytes value as UTF-8.
def _convert_bytes_to_str(parsed):
if isinstance(parsed, dict):
new_dict = {}
for key, value in parsed.items():
new_dict[key] = _convert_bytes_to_str(value)
return new_dict
elif isinstance(parsed, bytes):
return parsed.decode('utf-8')
elif isinstance(parsed, list):
new_list = []
for item in parsed:
new_list.append(_convert_bytes_to_str(item))
return new_list
# Recursively walk dicts and lists; for any other value, apply the first
# conversion function whose type matches. conversion_funcs is iterated as
# (type, function) pairs, and values matching no type are returned as-is.
def _convert(obj, conversion_funcs):
if isinstance(obj, dict):
return {k: _convert(v, conversion_funcs) for k, v in obj.items()}
elif isinstance(obj, list):
return [_convert(item, conversion_funcs) for item in obj]
else:
return parsed
for conv_type, conv_func in conversion_funcs:
if isinstance(obj, conv_type):
return conv_func(obj)
return obj


def _bytes_to_str(value):
if isinstance(value, bytes):
return value.decode('utf-8')
return value


def _convert_bytes_to_str(parsed):
    """Run ``_convert`` over ``parsed`` with ``_bytes_to_str`` registered for bytes."""
    bytes_conversions = [(bytes, _bytes_to_str)]
    return _convert(parsed, bytes_conversions)


def _special_floats_to_str(value):
if isinstance(value, float):
if value in [float('Infinity'), float('-Infinity')] or math.isnan(
value
):
return json.dumps(value)
return value


def _convert_special_floats_to_string(parsed):
    """Run ``_convert`` over ``parsed`` with ``_special_floats_to_str`` registered for floats."""
    float_conversions = [(float, _special_floats_to_str)]
    return _convert(parsed, float_conversions)


def _compliance_timestamp_parser(value):
# Parse `value` with parse_timestamp, shift the result to UTC, and
# return it as an epoch value.
# NOTE(review): presumably parse_timestamp returns a datetime.datetime
# (astimezone/timetuple/timestamp are called on it) -- confirm.
datetime = parse_timestamp(value)
# Convert from our time zone to UTC
datetime = datetime.astimezone(tzutc())
# Convert to epoch.
# int(timegm(...)) works from a time tuple, which carries no sub-second
# field, while .timestamp() returns a float.
return int(timegm(datetime.timetuple()))
return datetime.timestamp()


def _output_failure_message(
Expand Down

0 comments on commit 895d306

Please sign in to comment.