diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..461199e194 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,20 @@ +Release type: minor + +Starting with this release, any error raised from within schema +extensions will abort the operation and is returned to the client. + +This corresponds to the way we already handle field extension errors +and resolver errors. + +This is particular useful for schema extensions performing checks early +in the request lifecycle, for example: + +```python +class MaxQueryLengthExtension(SchemaExtension): + MAX_QUERY_LENGTH = 8192 + + async def on_operation(self): + if len(self.execution_context.query) > self.MAX_QUERY_LENGTH: + raise StrawberryGraphQLError(message="Query too large") + yield +``` diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index a763f9faa9..af0bd07a7f 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -88,77 +88,82 @@ async def execute( extensions=list(extensions), ) - async with extensions_runner.operation(): - # Note: In graphql-core the schema would be validated here but in - # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() - - async with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options + try: + async with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + if not execution_context.query: + raise MissingQueryError() + + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as exc: + execution_context.errors = [exc] + process_errors([exc], execution_context) + return ExecutionResult( + data=None, + errors=[exc], + extensions=await extensions_runner.get_extensions_results(), ) - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), - ) - - except Exception as error: # pragma: no cover - error = GraphQLError(str(error), original_error=error) - - execution_context.errors = [error] - process_errors([error], execution_context) - - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - async with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) - - async with extensions_runner.executing(): - if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) - - if isawaitable(result): - result = await cast(Awaitable["GraphQLExecutionResult"], result) - - result = cast("GraphQLExecutionResult", result) - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) + if execution_context.operation_type not in allowed_operation_types: + raise InvalidOperationTypeError(execution_context.operation_type) + + async with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + + async with extensions_runner.executing(): + if not execution_context.result: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, + ) + + if isawaitable(result): + result = await cast(Awaitable["GraphQLExecutionResult"], result) + + result = cast("GraphQLExecutionResult", result) + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + except (MissingQueryError, InvalidOperationTypeError) as e: + raise e + except Exception as exc: + error = ( + exc + if isinstance(exc, GraphQLError) + else GraphQLError(str(exc), original_error=exc) + ) + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) return ExecutionResult( data=execution_context.result.data, @@ -181,80 +186,86 @@ def execute_sync( extensions=list(extensions), ) - with extensions_runner.operation(): - # Note: In graphql-core the schema would be validated here but in - # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() - - with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options + try: + with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + if not execution_context.query: + raise MissingQueryError() + + with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as exc: + execution_context.errors = [exc] + process_errors([exc], execution_context) + return ExecutionResult( + data=None, + errors=[exc], + extensions=extensions_runner.get_extensions_results_sync(), ) - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=extensions_runner.get_extensions_results_sync(), - ) - - except Exception as error: # pragma: no cover - error = GraphQLError(str(error), original_error=error) - - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=extensions_runner.get_extensions_results_sync(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) - - with extensions_runner.executing(): - if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) - - if isawaitable(result): - result = cast(Awaitable["GraphQLExecutionResult"], result) - ensure_future(result).cancel() - raise RuntimeError( - "GraphQL execution failed to complete synchronously." + if execution_context.operation_type not in allowed_operation_types: + raise InvalidOperationTypeError(execution_context.operation_type) + + with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + + with extensions_runner.executing(): + if not execution_context.result: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, ) - result = cast("GraphQLExecutionResult", result) - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) + if isawaitable(result): + result = cast(Awaitable["GraphQLExecutionResult"], result) + ensure_future(result).cancel() + raise RuntimeError( + "GraphQL execution failed to complete synchronously." + ) + + result = cast("GraphQLExecutionResult", result) + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + except (MissingQueryError, InvalidOperationTypeError) as e: + raise e + except Exception as exc: + error = ( + exc + if isinstance(exc, GraphQLError) + else GraphQLError(str(exc), original_error=exc) + ) + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=extensions_runner.get_extensions_results_sync(), + ) return ExecutionResult( data=execution_context.result.data, diff --git a/tests/schema/extensions/test_extensions.py b/tests/schema/extensions/test_extensions.py index 555d83063c..9b527289fa 100644 --- a/tests/schema/extensions/test_extensions.py +++ b/tests/schema/extensions/test_extensions.py @@ -467,8 +467,10 @@ def on_executing_start(self): schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[WrongUsageExtension] ) - with pytest.raises(ValueError): - schema.execute_sync(default_query_types_and_query.query) + + result = schema.execute_sync(default_query_types_and_query.query) + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, ValueError) async def test_legacy_extension_supported(): @@ -628,8 +630,184 @@ def string(self) -> str: schema.execute_sync(query) +class ExceptionTestingExtension(SchemaExtension): + def __init__(self, failing_hook: str): + self.failing_hook = failing_hook + self.called_hooks = set() + + def on_operation(self): + if self.failing_hook == "on_operation_start": + raise Exception(self.failing_hook) + self.called_hooks.add(1) + + with contextlib.suppress(Exception): + yield + + if self.failing_hook == "on_operation_end": + raise Exception(self.failing_hook) + self.called_hooks.add(8) + + def on_parse(self): + if self.failing_hook == "on_parse_start": + raise Exception(self.failing_hook) + self.called_hooks.add(2) + + with contextlib.suppress(Exception): + yield + + if self.failing_hook == "on_parse_end": + raise Exception(self.failing_hook) + self.called_hooks.add(3) + + def on_validate(self): + if self.failing_hook == "on_validate_start": + raise Exception(self.failing_hook) + self.called_hooks.add(4) + + with contextlib.suppress(Exception): + yield + + if self.failing_hook == "on_validate_end": + raise Exception(self.failing_hook) + self.called_hooks.add(5) + + def on_execute(self): + if self.failing_hook == "on_execute_start": + raise Exception(self.failing_hook) + self.called_hooks.add(6) + + with contextlib.suppress(Exception): + yield + + if self.failing_hook == "on_execute_end": + raise Exception(self.failing_hook) + self.called_hooks.add(7) + + +@pytest.mark.parametrize( + "failing_hook", + ( + "on_operation_start", + "on_operation_end", + "on_parse_start", + "on_parse_end", + "on_validate_start", + "on_validate_end", + "on_execute_start", + "on_execute_end", + ), +) +@pytest.mark.asyncio +async def test_exceptions_are_included_in_the_execution_result(failing_hook): + @strawberry.type + class Query: + @strawberry.field + def ping(self) -> str: + return "pong" + + schema = strawberry.Schema( + query=Query, + extensions=[ExceptionTestingExtension(failing_hook)], + ) + document = "query { ping }" + + sync_result = schema.execute_sync(document) + assert sync_result.errors is not None + assert len(sync_result.errors) == 1 + assert sync_result.errors[0].message == failing_hook + + async_result = await schema.execute(document) + assert async_result.errors is not None + assert len(async_result.errors) == 1 + assert sync_result.errors[0].message == failing_hook + + +@pytest.mark.parametrize( + ("failing_hook", "expected_hooks"), + ( + ("on_operation_start", set()), + ("on_parse_start", {1, 8}), + ("on_parse_end", {1, 2, 8}), + ("on_validate_start", {1, 2, 3, 8}), + ("on_validate_end", {1, 2, 3, 4, 8}), + ("on_execute_start", {1, 2, 3, 4, 5, 8}), + ("on_execute_end", {1, 2, 3, 4, 5, 6, 8}), + ("on_operation_end", {1, 2, 3, 4, 5, 6, 7}), + ), +) +@pytest.mark.asyncio +async def test_exceptions_abort_evaluation(failing_hook, expected_hooks): + @strawberry.type + class Query: + @strawberry.field + def ping(self) -> str: + return "pong" + + extension = ExceptionTestingExtension(failing_hook) + schema = strawberry.Schema(query=Query, extensions=[extension]) + document = "query { ping }" + + extension.called_hooks = set() + schema.execute_sync(document) + assert extension.called_hooks == expected_hooks + + extension.called_hooks = set() + await schema.execute(document) + assert extension.called_hooks == expected_hooks + + +async def test_generic_exceptions_get_wrapped_in_a_graphql_error(): + exception = Exception("This should be wrapped in a GraphQL error") + + class MyExtension(SchemaExtension): + def on_parse(self): + raise exception + + @strawberry.type + class Query: + ping: str = "pong" + + schema = strawberry.Schema(query=Query, extensions=[MyExtension]) + query = "query { ping }" + + sync_result = schema.execute_sync(query) + assert len(sync_result.errors) == 1 + assert isinstance(sync_result.errors[0], GraphQLError) + assert sync_result.errors[0].original_error == exception + + async_result = await schema.execute(query) + assert len(async_result.errors) == 1 + assert isinstance(async_result.errors[0], GraphQLError) + assert async_result.errors[0].original_error == exception + + +async def test_graphql_errors_get_not_wrapped_in_a_graphql_error(): + exception = GraphQLError("This should not be wrapped in a GraphQL error") + + class MyExtension(SchemaExtension): + def on_parse(self): + raise exception + + @strawberry.type + class Query: + ping: str = "pong" + + schema = strawberry.Schema(query=Query, extensions=[MyExtension]) + query = "query { ping }" + + sync_result = schema.execute_sync(query) + assert len(sync_result.errors) == 1 + assert sync_result.errors[0] == exception + assert sync_result.errors[0].original_error is None + + async_result = await schema.execute(query) + assert len(async_result.errors) == 1 + assert async_result.errors[0] == exception + assert async_result.errors[0].original_error is None + + @pytest.mark.asyncio -async def test_dont_swallow_errors_in_parsing_hooks(): +async def test_non_parsing_errors_are_not_swallowed_by_parsing_hooks(): class MyExtension(SchemaExtension): def on_parse(self): raise Exception("This shouldn't be swallowed") @@ -643,14 +821,16 @@ def ping(self) -> str: schema = strawberry.Schema(query=Query, extensions=[MyExtension]) query = "query { string }" - with pytest.raises(Exception, match="This shouldn't be swallowed"): - schema.execute_sync(query) + sync_result = schema.execute_sync(query) + assert len(sync_result.errors) == 1 + assert sync_result.errors[0].message == "This shouldn't be swallowed" - with pytest.raises(Exception, match="This shouldn't be swallowed"): - await schema.execute(query) + async_result = await schema.execute(query) + assert len(async_result.errors) == 1 + assert async_result.errors[0].message == "This shouldn't be swallowed" -def test_on_parsing_end_called_when_errors(): +def test_on_parsing_end_is_called_with_parsing_errors(): execution_errors = False class MyExtension(SchemaExtension): @@ -696,7 +876,8 @@ def on_execute(self): class Query: food: str = "strawberry" - schema = strawberry.Schema(query=Query, extensions=[ExtensionB, ExtensionC]) + extensions = [ExtensionB, ExtensionC] + schema = strawberry.Schema(query=Query, extensions=extensions) query = """ query TestQuery { @@ -722,8 +903,9 @@ class Query: schema = strawberry.Schema(query=Query, extensions=[ExtensionA]) - with pytest.raises(RuntimeError, match="failed to complete synchronously"): - schema.execute_sync("query { food }") + result = schema.execute_sync("query { food }") + assert len(result.errors) == 1 + assert result.errors[0].message.endswith("failed to complete synchronously.") def test_extension_override_execution(): @@ -1021,24 +1203,8 @@ def hi(self) -> str: # Query not set on input query = "{ hi }" - with pytest.raises( - ValueError, match="Hook on_operation on <(.*)> must be callable, received 'ABC'" - ): - schema.execute_sync(query) - - -@pytest.mark.asyncio -async def test_calls_hooks_when_there_are_errors(async_extension): - @strawberry.type - class Query: - @strawberry.field - def hi(self) -> str: - raise Exception("This is an error") - - schema = strawberry.Schema(query=Query, extensions=[async_extension]) - - query = "{ hi }" - - result = await schema.execute(query) - assert result.errors - async_extension.perform_test() + result = schema.execute_sync(query) + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, ValueError) + assert result.errors[0].message.startswith("Hook on_operation on <") + assert result.errors[0].message.endswith("> must be callable, received 'ABC'") diff --git a/tests/schema/extensions/test_parser_cache.py b/tests/schema/extensions/test_parser_cache.py index 521c561ac7..20f7f540ad 100644 --- a/tests/schema/extensions/test_parser_cache.py +++ b/tests/schema/extensions/test_parser_cache.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from graphql import parse +from graphql import SourceLocation, parse import strawberry from strawberry.extensions import MaxTokensLimiter, ParserCache @@ -75,6 +75,26 @@ def ping(self) -> str: mock_parse.assert_called_with("query { hello }", max_tokens=20) +@patch("strawberry.schema.execute.parse", wraps=parse) +def test_parser_cache_extension_syntax_error(mock_parse): + @strawberry.type + class Query: + @strawberry.field + def hello(self) -> str: # pragma: no cover + return "world" + + schema = strawberry.Schema(query=Query, extensions=[ParserCache()]) + + query = "query { hello" + + result = schema.execute_sync(query) + + assert len(result.errors) == 1 + assert result.errors[0].message == "Syntax Error: Expected Name, found ." + assert result.errors[0].locations == [SourceLocation(line=1, column=14)] + assert mock_parse.call_count == 1 + + @patch("strawberry.schema.execute.parse", wraps=parse) def test_parser_cache_extension_max_size(mock_parse): @strawberry.type diff --git a/tests/schema/test_execution_errors.py b/tests/schema/test_execution_errors.py index e4fa50b2fe..513904e4b6 100644 --- a/tests/schema/test_execution_errors.py +++ b/tests/schema/test_execution_errors.py @@ -38,10 +38,13 @@ async def name(self) -> str: } """ - with pytest.raises(RuntimeError) as e: - schema.execute_sync(query) - - assert e.value.args[0] == "GraphQL execution failed to complete synchronously." + result = schema.execute_sync(query) + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, RuntimeError) + assert ( + result.errors[0].message + == "GraphQL execution failed to complete synchronously." + ) @pytest.mark.asyncio diff --git a/tests/types/test_execution.py b/tests/types/test_execution.py index a191c63dc8..c7bce4f2bc 100644 --- a/tests/types/test_execution.py +++ b/tests/types/test_execution.py @@ -1,5 +1,3 @@ -import pytest - import strawberry from strawberry.extensions import SchemaExtension @@ -150,8 +148,10 @@ def on_operation(self): schema = strawberry.Schema(Query, extensions=[MyExtension]) - with pytest.raises(RuntimeError): - schema.execute_sync("mutation { myMutation }") + result = schema.execute_sync("mutation { myMutation }") + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, RuntimeError) + assert result.errors[0].message == "No GraphQL document available" def test_error_when_accessing_operation_type_with_invalid_operation_name(): @@ -165,5 +165,7 @@ def on_parse(self): schema = strawberry.Schema(Query, extensions=[MyExtension]) - with pytest.raises(RuntimeError): - schema.execute_sync("query { ping }", operation_name="MyQuery") + result = schema.execute_sync("query { ping }", operation_name="MyQuery") + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, RuntimeError) + assert result.errors[0].message == "Can't get GraphQL operation type"