diff --git a/botocore/handlers.py b/botocore/handlers.py index 3c86150781..e9d4dbe4ba 100644 --- a/botocore/handlers.py +++ b/botocore/handlers.py @@ -1269,6 +1269,17 @@ def _update_status_code(response, **kwargs): http_response.status_code = parsed_status_code +def handle_request_validation_mode_member(params, model, **kwargs): + client_config = kwargs.get("context", {}).get("client_config") + if client_config is None: + return + response_checksum_validation = client_config.response_checksum_validation + http_checksum = model.http_checksum + mode_member = http_checksum.get("requestValidationModeMember") + if mode_member and response_checksum_validation == "when_supported": + params.setdefault(mode_member, "ENABLED") + + # This is a list of (event_name, handler). # When a Session is created, everything in this list will be # automatically registered with that Session. @@ -1301,6 +1312,7 @@ def _update_status_code(response, **kwargs): ('before-parse.s3.*', handle_expires_header), ('before-parse.s3.*', _handle_200_error, REGISTER_FIRST), ('before-parameter-build', generate_idempotent_uuid), + ('before-parameter-build', handle_request_validation_mode_member), ('before-parameter-build.s3', validate_bucket_name), ('before-parameter-build.s3', remove_bucket_from_url_paths_from_model), ( diff --git a/tests/functional/test_httpchecksum.py b/tests/functional/test_httpchecksum.py index 07f6691a69..82811846ce 100644 --- a/tests/functional/test_httpchecksum.py +++ b/tests/functional/test_httpchecksum.py @@ -15,6 +15,7 @@ import pytest from botocore.compat import HAS_CRT +from botocore.exceptions import FlexibleChecksumError from tests import ClientHTTPStubber, patch_load_service_model TEST_CHECKSUM_SERVICE_MODEL = { @@ -83,6 +84,8 @@ "Blob": {"type": "blob"}, "SomeStreamingOutput": { "type": "structure", + "members": {"body": {"shape": "Blob", "streaming": True}}, + "payload": "body", }, "SomeStreamingInput": { "type": "structure", @@ -324,3 +327,175 @@ def test_streaming_request_checksum_calculation( read_body = request.body.read() for key, val in expected_trailers.items(): assert f"{key}:{val}".encode() in read_body + + +def _response_checksum_validation_cases(): + response_payload = "Hello world" + cases = [ + ( + "CRC32", + response_payload, + {"x-amz-checksum-crc32": "i9aeUg=="}, + {"kind": "success"}, + ), + ( + "CRC32", + response_payload, + {"x-amz-checksum-crc32": "bm90LWEtY2hlY2tzdW0="}, + {"kind": "failure", "calculatedChecksum": "i9aeUg=="}, + ), + ( + "SHA1", + response_payload, + {"x-amz-checksum-sha1": "e1AsOh9IyGCa4hLN+2Od7jlnP14="}, + {"kind": "success"}, + ), + ( + "SHA1", + response_payload, + {"x-amz-checksum-sha1": "bm90LWEtY2hlY2tzdW0="}, + { + "kind": "failure", + "calculatedChecksum": "e1AsOh9IyGCa4hLN+2Od7jlnP14=", + }, + ), + ( + "SHA256", + response_payload, + { + "x-amz-checksum-sha256": "ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=" + }, + {"kind": "success"}, + ), + ( + "SHA256", + response_payload, + {"x-amz-checksum-sha256": "bm90LWEtY2hlY2tzdW0="}, + { + "kind": "failure", + "calculatedChecksum": "ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=", + }, + ), + ] + if HAS_CRT: + cases.extend( + [ + ( + "CRC32C", + response_payload, + {"x-amz-checksum-crc32c": "crUfeA=="}, + {"kind": "success"}, + ), + ( + "CRC32C", + response_payload, + {"x-amz-checksum-crc32c": "bm90LWEtY2hlY2tzdW0="}, + {"kind": "failure", "calculatedChecksum": "crUfeA=="}, + ), + ( + "CRC64NVME", + response_payload, + {"x-amz-checksum-crc64nvme": "OOJZ0D8xKts="}, + {"kind": "success"}, + ), + ( + "CRC64NVME", + response_payload, + {"x-amz-checksum-crc64nvme": "bm90LWEtY2hlY2tzdW0="}, + {"kind": "failure", "calculatedChecksum": "OOJZ0D8xKts="}, + ), + ] + ) + return cases + + +@pytest.mark.parametrize( + "checksum_algorithm, response_payload, response_headers, expected", + _response_checksum_validation_cases(), +) +def test_response_checksum_validation( + patched_session, + monkeypatch, + checksum_algorithm, + response_payload, + response_headers, + expected, +): + patch_load_service_model( + patched_session, + monkeypatch, + TEST_CHECKSUM_SERVICE_MODEL, + TEST_CHECKSUM_RULESET, + ) + client = patched_session.create_client( + "testservice", + region_name="us-west-2", + ) + with ClientHTTPStubber(client, strict=True) as http_stubber: + http_stubber.add_response( + status=200, + body=response_payload.encode(), + headers=response_headers, + ) + operation_kwargs = { + "body": response_payload, + "checksumAlgorithm": checksum_algorithm, + } + if expected["kind"] == "failure": + with pytest.raises(FlexibleChecksumError) as expected_error: + client.http_checksum_operation(**operation_kwargs) + error_msg = "Expected checksum {} did not match calculated checksum: {}".format( + response_headers[ + f'x-amz-checksum-{checksum_algorithm.lower()}' + ], + expected['calculatedChecksum'], + ) + assert str(expected_error.value) == error_msg + else: + client.http_checksum_operation(**operation_kwargs) + + +@pytest.mark.parametrize( + "checksum_algorithm, response_payload, response_headers, expected", + _response_checksum_validation_cases(), +) +def test_streaming_response_checksum_validation( + patched_session, + monkeypatch, + checksum_algorithm, + response_payload, + response_headers, + expected, +): + patch_load_service_model( + patched_session, + monkeypatch, + TEST_CHECKSUM_SERVICE_MODEL, + TEST_CHECKSUM_RULESET, + ) + client = patched_session.create_client( + "testservice", + region_name="us-west-2", + ) + with ClientHTTPStubber(client, strict=True) as http_stubber: + http_stubber.add_response( + status=200, + body=response_payload.encode(), + headers=response_headers, + ) + response = client.http_checksum_streaming_operation( + body=response_payload, + checksumAlgorithm=checksum_algorithm, + ) + if expected["kind"] == "failure": + with pytest.raises(FlexibleChecksumError) as expected_error: + response["body"].read() + error_msg = "Expected checksum {} did not match calculated checksum: {}".format( + response_headers[ + f'x-amz-checksum-{checksum_algorithm.lower()}' + ], + expected['calculatedChecksum'], + ) + assert str(expected_error.value) == error_msg + else: + response["body"].read() diff --git a/tests/unit/test_handlers.py b/tests/unit/test_handlers.py index 924abed8af..8361b128b6 100644 --- a/tests/unit/test_handlers.py +++ b/tests/unit/test_handlers.py @@ -1944,3 +1944,56 @@ def test_document_response_params_without_expires(document_expires_mocks): mocks['section'].get_section.assert_not_called() mocks['param_section'].add_new_section.assert_not_called() mocks['doc_section'].write.assert_not_called() + + +@pytest.fixture() +def checksum_operation_model(): + operation_model = mock.Mock(spec=OperationModel) + operation_model.http_checksum = { + "requestValidationModeMember": "ChecksumMode", + } + return operation_model + + +def create_checksum_context( + request_checksum_calculation="when_supported", + response_checksum_validation="when_supported", +): + context = { + "client_config": Config( + request_checksum_calculation=request_checksum_calculation, + response_checksum_validation=response_checksum_validation, + ) + } + return context + + +def test_request_validation_mode_member_default(checksum_operation_model): + params = {} + handlers.handle_request_validation_mode_member( + params, checksum_operation_model, context=create_checksum_context() + ) + assert params["ChecksumMode"] == "ENABLED" + + +def test_request_validation_mode_member_when_required( + checksum_operation_model, +): + params = {} + context = create_checksum_context( + response_checksum_validation="when_required" + ) + handlers.handle_request_validation_mode_member( + params, checksum_operation_model, context=context + ) + assert "ChecksumMode" not in params + + +def test_request_validation_mode_member_is_not_enabled( + checksum_operation_model, +): + params = {"ChecksumMode": "FAKE_VALUE"} + handlers.handle_request_validation_mode_member( + params, checksum_operation_model, context=create_checksum_context() + ) + assert params["ChecksumMode"] == "FAKE_VALUE"