Skip to content

Commit

Permalink
Improve extensions error handling (#3217)
Browse files Browse the repository at this point in the history
* Clarify names of test cases

* Remove unnecessary except block

* Handle exceptions raised within extension hooks

* Adjust tests depending on previous behaviour

* Make ruff happy :)

* Lint

* Add release file

* chore: add tests for syntax errors on parser cache

* Update test

* Add # pragma: no cover

* Fix graphql errors were unnecessarily wrapped

* Restore test

* Restore

* Pending exceptions

* Restore code

* Quit after one exception

* Fix type

* Remove pending exception

* Fix test

---------

Co-authored-by: Connor Lewis <[email protected]>
Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
3 people authored Apr 17, 2024
1 parent 861cc0d commit 63dfc89
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 182 deletions.
20 changes: 20 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Release type: minor

Starting with this release, any error raised from within schema
extensions will abort the operation and is returned to the client.

This corresponds to the way we already handle field extension errors
and resolver errors.

This is particular useful for schema extensions performing checks early
in the request lifecycle, for example:

```python
class MaxQueryLengthExtension(SchemaExtension):
MAX_QUERY_LENGTH = 8192

async def on_operation(self):
if len(self.execution_context.query) > self.MAX_QUERY_LENGTH:
raise StrawberryGraphQLError(message="Query too large")
yield
```
289 changes: 150 additions & 139 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,77 +88,82 @@ async def execute(
extensions=list(extensions),
)

async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
try:
async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as exc:
execution_context.errors = [exc]
process_errors([exc], execution_context)
return ExecutionResult(
data=None,
errors=[exc],
extensions=await extensions_runner.get_extensions_results(),
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)

return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

async with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = await cast(Awaitable["GraphQLExecutionResult"], result)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)
if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

async with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = await cast(Awaitable["GraphQLExecutionResult"], result)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)

except (MissingQueryError, InvalidOperationTypeError) as e:
raise e
except Exception as exc:
error = (
exc
if isinstance(exc, GraphQLError)
else GraphQLError(str(exc), original_error=exc)
)
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

return ExecutionResult(
data=execution_context.result.data,
Expand All @@ -181,80 +186,86 @@ def execute_sync(
extensions=list(extensions),
)

with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
try:
with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as exc:
execution_context.errors = [exc]
process_errors([exc], execution_context)
return ExecutionResult(
data=None,
errors=[exc],
extensions=extensions_runner.get_extensions_results_sync(),
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)

execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = cast(Awaitable["GraphQLExecutionResult"], result)
ensure_future(result).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)
if isawaitable(result):
result = cast(Awaitable["GraphQLExecutionResult"], result)
ensure_future(result).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)

except (MissingQueryError, InvalidOperationTypeError) as e:
raise e
except Exception as exc:
error = (
exc
if isinstance(exc, GraphQLError)
else GraphQLError(str(exc), original_error=exc)
)
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

return ExecutionResult(
data=execution_context.result.data,
Expand Down
Loading

0 comments on commit 63dfc89

Please sign in to comment.