Skip to content

Commit

Permalink
Extract operation name and type from execution context (#1286)
Browse files Browse the repository at this point in the history
* Extract operation name and type from execution context

* Fix open telemetry extension and make the operation name fetching more
forgiving

* Code review updates

* Add RELEASE file

* Add test for runtime exception

* Rename variable

* Change version to minor

Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
jkimbo and patrick91 authored Dec 7, 2021
1 parent 02e500e commit 91cbfb8
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 7 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

This release `operation_type` to the `ExecutionContext` type that is available
in extensions. It also gets the `operation_name` from the query if one isn't
provided by the client.
13 changes: 11 additions & 2 deletions strawberry/extensions/tracing/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def __init__(
self.execution_context = execution_context

def on_request_start(self):
self._operation_name = self.execution_context.operation_name
span_name = (
f"GraphQL Query: {self.execution_context.operation_name}"
if self.execution_context.operation_name
f"GraphQL Query: {self._operation_name}"
if self._operation_name
else "GraphQL Query"
)

Expand All @@ -58,6 +59,14 @@ def on_request_start(self):
)

def on_request_end(self):
# If the client doesn't provide an operation name then GraphQL will
# execute the first operation in the query string. This might be a named
# operation but we don't know until the parsing stage has finished. If
# that's the case we want to update the span name so that we have a more
# useful name in our trace.
if not self._operation_name and self.execution_context.operation_name:
span_name = f"GraphQL Query: {self.execution_context.operation_name}"
self._span_holder[RequestStage.REQUEST].update_name(span_name)
self._span_holder[RequestStage.REQUEST].end()

def on_validation_start(self):
Expand Down
4 changes: 2 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def execute(
context=context_value,
root_value=root_value,
variables=variable_values,
operation_name=operation_name,
provided_operation_name=operation_name,
)

result = await execute(
Expand Down Expand Up @@ -167,7 +167,7 @@ def execute_sync(
context=context_value,
root_value=root_value,
variables=variable_values,
operation_name=operation_name,
provided_operation_name=operation_name,
)

result = execute_sync(
Expand Down
67 changes: 64 additions & 3 deletions strawberry/types/execution.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,98 @@
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast

from typing_extensions import Literal

from graphql import (
ASTValidationRule,
ExecutionResult as GraphQLExecutionResult,
specified_rules,
)
from graphql.error.graphql_error import GraphQLError
from graphql.language import DocumentNode
from graphql.language import DocumentNode, OperationDefinitionNode


if TYPE_CHECKING:
from strawberry.schema import Schema


GraphqlOperationTypes = Literal["QUERY", "MUTATION"]


@dataclasses.dataclass
class ExecutionContext:
query: str
schema: "Schema"
context: Any = None
variables: Optional[Dict[str, Any]] = None
operation_name: Optional[str] = None
root_value: Optional[Any] = None
validation_rules: Tuple[Type[ASTValidationRule], ...] = dataclasses.field(
default_factory=lambda: tuple(specified_rules)
)

# The operation name that is provided by the request
provided_operation_name: dataclasses.InitVar[Optional[str]] = None

# Values that get populated during the GraphQL execution so that they can be
# accessed by extensions
graphql_document: Optional[DocumentNode] = None
errors: Optional[List[GraphQLError]] = None
result: Optional[GraphQLExecutionResult] = None

def __post_init__(self, provided_operation_name):
self._provided_operation_name = provided_operation_name

@property
def operation_name(self) -> Optional[str]:
if self._provided_operation_name:
return self._provided_operation_name

definition = self._get_first_operation()
if not definition:
return None

if not definition.name:
return None

return definition.name.value

@property
def operation_type(self) -> GraphqlOperationTypes:
definition: Optional[OperationDefinitionNode] = None

graphql_document = self.graphql_document
if not graphql_document:
raise RuntimeError("No GraphQL document available")

# If no operation_name has been specified then use the first
# OperationDefinitionNode
if not self._provided_operation_name:
definition = self._get_first_operation()
else:
for d in graphql_document.definitions:
d = cast(OperationDefinitionNode, d)
if d.name and d.name.value == self._provided_operation_name:
definition = d
break

if not definition:
raise RuntimeError("Can't get GraphQL operation type")

return cast(GraphqlOperationTypes, definition.operation.name)

def _get_first_operation(self) -> Optional[OperationDefinitionNode]:
graphql_document = self.graphql_document
if not graphql_document:
return None

definition: Optional[OperationDefinitionNode] = None
for d in graphql_document.definitions:
if isinstance(d, OperationDefinitionNode):
definition = d
break

return definition


@dataclasses.dataclass
class ExecutionResult:
Expand Down
31 changes: 31 additions & 0 deletions tests/schema/extensions/test_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ async def test_open_tracing_uses_operation_name(global_tracer_mock, mocker):
)


@pytest.mark.asyncio
async def test_open_tracing_gets_operation_name(global_tracer_mock, mocker):
schema = strawberry.Schema(query=Query, extensions=[OpenTelemetryExtension])
query = """
query Example {
person {
name
}
}
"""

tracers = []

def generate_trace(*args, **kwargs):
nonlocal tracers
tracer = mocker.Mock()
tracers.append(tracer)
return tracer

global_tracer_mock.return_value.start_span.side_effect = generate_trace

await schema.execute(query)

tracers[0].update_name.assert_has_calls(
[
# if operation_name is supplied it is added to this span's tag
mocker.call("GraphQL Query: Example"),
]
)


@pytest.mark.asyncio
async def test_tracing_add_kwargs(global_tracer_mock, mocker):
@strawberry.type
Expand Down
165 changes: 165 additions & 0 deletions tests/types/test_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pytest

import strawberry
from strawberry.extensions import Extension


@strawberry.type
class Query:
@strawberry.field
def ping(self) -> str:
return "pong"


def test_execution_context_operation_name_and_type():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

result = schema.execute_sync("{ ping }")
assert not result.errors

assert operation_name is None
assert operation_type == "QUERY"

# Try again with an operation_name
result = schema.execute_sync("query MyOperation { ping }")
assert not result.errors

assert operation_name == "MyOperation"
assert operation_type == "QUERY"

# Try again with an operation_name override
result = schema.execute_sync(
"""
query MyOperation { ping }
query MyOperation2 { ping }
""",
operation_name="MyOperation2",
)
assert not result.errors

assert operation_name == "MyOperation2"
assert operation_type == "QUERY"


def test_execution_context_operation_type_mutation():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

@strawberry.type
class Mutation:
@strawberry.mutation
def my_mutation(self) -> str:
return "hi"

schema = strawberry.Schema(Query, mutation=Mutation, extensions=[MyExtension])

result = schema.execute_sync("mutation { myMutation }")
assert not result.errors

assert operation_name is None
assert operation_type == "MUTATION"

# Try again with an operation_name
result = schema.execute_sync("mutation MyMutation { myMutation }")
assert not result.errors

assert operation_name == "MyMutation"
assert operation_type == "MUTATION"

# Try again with an operation_name override
result = schema.execute_sync(
"""
mutation MyMutation { myMutation }
mutation MyMutation2 { myMutation }
""",
operation_name="MyMutation2",
)
assert not result.errors

assert operation_name == "MyMutation2"
assert operation_type == "MUTATION"


def test_execution_context_operation_name_and_type_with_fragmenets():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

result = schema.execute_sync(
"""
fragment MyFragment on Query {
ping
}
query MyOperation {
ping
...MyFragment
}
"""
)
assert not result.errors

assert operation_name == "MyOperation"
assert operation_type == "QUERY"


def test_error_when_accessing_operation_type_before_parsing():
class MyExtension(Extension):
def on_request_start(self):
execution_context = self.execution_context

# This should raise a RuntimeError
execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

with pytest.raises(RuntimeError):
schema.execute_sync("mutation { myMutation }")


def test_error_when_accessing_operation_type_with_invalid_operation_name():
class MyExtension(Extension):
def on_parsing_end(self):
execution_context = self.execution_context

# This should raise a RuntimeError
execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

with pytest.raises(RuntimeError):
schema.execute_sync("query { ping }", operation_name="MyQuery")

0 comments on commit 91cbfb8

Please sign in to comment.