Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract operation name and type from execution context #1286

Merged
merged 7 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: patch
patrick91 marked this conversation as resolved.
Show resolved Hide resolved

This release `operation_type` to the `ExecutionContext` type that is available
in extensions. It also gets the `operation_name` from the query if one isn't
provided by the client.
13 changes: 11 additions & 2 deletions strawberry/extensions/tracing/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def __init__(
self.execution_context = execution_context

def on_request_start(self):
self._operation_name = self.execution_context.operation_name
span_name = (
f"GraphQL Query: {self.execution_context.operation_name}"
if self.execution_context.operation_name
f"GraphQL Query: {self._operation_name}"
if self._operation_name
else "GraphQL Query"
)

Expand All @@ -58,6 +59,14 @@ def on_request_start(self):
)

def on_request_end(self):
# If the client doesn't provide an operation name then GraphQL will
# execute the first operation in the query string. This might be a named
# operation but we don't know until the parsing stage has finished. If
# that's the case we want to update the span name so that we have a more
# useful name in our trace.
if not self._operation_name and self.execution_context.operation_name:
span_name = f"GraphQL Query: {self.execution_context.operation_name}"
self._span_holder[RequestStage.REQUEST].update_name(span_name)
self._span_holder[RequestStage.REQUEST].end()

def on_validation_start(self):
Expand Down
4 changes: 2 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def execute(
context=context_value,
root_value=root_value,
variables=variable_values,
operation_name=operation_name,
provided_operation_name=operation_name,
)

result = await execute(
Expand Down Expand Up @@ -168,7 +168,7 @@ def execute_sync(
context=context_value,
root_value=root_value,
variables=variable_values,
operation_name=operation_name,
provided_operation_name=operation_name,
)

result = execute_sync(
Expand Down
67 changes: 64 additions & 3 deletions strawberry/types/execution.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,98 @@
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast

from typing_extensions import Literal

from graphql import (
ASTValidationRule,
ExecutionResult as GraphQLExecutionResult,
specified_rules,
)
from graphql.error.graphql_error import GraphQLError
from graphql.language import DocumentNode
from graphql.language import DocumentNode, OperationDefinitionNode


if TYPE_CHECKING:
from strawberry.schema import Schema


GraphqlOperationTypes = Literal["QUERY", "MUTATION"]


@dataclasses.dataclass
class ExecutionContext:
query: str
schema: "Schema"
context: Any = None
variables: Optional[Dict[str, Any]] = None
operation_name: Optional[str] = None
root_value: Optional[Any] = None
validation_rules: Tuple[Type[ASTValidationRule], ...] = dataclasses.field(
default_factory=lambda: tuple(specified_rules)
)

# The operation name that is provided by the request
provided_operation_name: dataclasses.InitVar[Optional[str]] = None

# Values that get populated during the GraphQL execution so that they can be
# accessed by extensions
graphql_document: Optional[DocumentNode] = None
errors: Optional[List[GraphQLError]] = None
result: Optional[GraphQLExecutionResult] = None

def __post_init__(self, provided_operation_name):
self._provided_operation_name = provided_operation_name

@property
def operation_name(self) -> Optional[str]:
if self._provided_operation_name:
return self._provided_operation_name

definition = self._get_first_operation()
if not definition:
return None

if not definition.name:
return None

return definition.name.value

@property
def operation_type(self) -> GraphqlOperationTypes:
definition: Optional[OperationDefinitionNode] = None

graphql_document = self.graphql_document
if not graphql_document:
raise RuntimeError("No GraphQL document available")

# If no operation_name has been specified then use the first
# OperationDefinitionNode
if not self._provided_operation_name:
definition = self._get_first_operation()
else:
for d in graphql_document.definitions:
d = cast(OperationDefinitionNode, d)
if d.name and d.name.value == self._provided_operation_name:
definition = d
break

if not definition:
raise RuntimeError("Can't get GraphQL operation type")

return cast(GraphqlOperationTypes, definition.operation.name)

def _get_first_operation(self) -> Optional[OperationDefinitionNode]:
graphql_document = self.graphql_document
if not graphql_document:
return None

definition: Optional[OperationDefinitionNode] = None
for d in graphql_document.definitions:
if isinstance(d, OperationDefinitionNode):
definition = d
break

return definition


@dataclasses.dataclass
class ExecutionResult:
Expand Down
31 changes: 31 additions & 0 deletions tests/schema/extensions/test_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ async def test_open_tracing_uses_operation_name(global_tracer_mock, mocker):
)


@pytest.mark.asyncio
async def test_open_tracing_gets_operation_name(global_tracer_mock, mocker):
schema = strawberry.Schema(query=Query, extensions=[OpenTelemetryExtension])
query = """
query Example {
person {
name
}
}
"""

tracers = []

def generate_trace(*args, **kwargs):
nonlocal tracers
tracer = mocker.Mock()
tracers.append(tracer)
return tracer

global_tracer_mock.return_value.start_span.side_effect = generate_trace

await schema.execute(query)

tracers[0].update_name.assert_has_calls(
[
# if operation_name is supplied it is added to this span's tag
mocker.call("GraphQL Query: Example"),
]
)


@pytest.mark.asyncio
async def test_tracing_add_kwargs(global_tracer_mock, mocker):
@strawberry.type
Expand Down
165 changes: 165 additions & 0 deletions tests/types/test_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pytest

import strawberry
from strawberry.extensions import Extension


@strawberry.type
class Query:
@strawberry.field
def ping(self) -> str:
return "pong"


def test_execution_context_operation_name_and_type():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

result = schema.execute_sync("{ ping }")
assert not result.errors

assert operation_name is None
assert operation_type == "QUERY"

# Try again with an operation_name
result = schema.execute_sync("query MyOperation { ping }")
assert not result.errors

assert operation_name == "MyOperation"
assert operation_type == "QUERY"

# Try again with an operation_name override
result = schema.execute_sync(
"""
query MyOperation { ping }
query MyOperation2 { ping }
""",
operation_name="MyOperation2",
)
assert not result.errors

assert operation_name == "MyOperation2"
assert operation_type == "QUERY"


def test_execution_context_operation_type_mutation():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

@strawberry.type
class Mutation:
@strawberry.mutation
def my_mutation(self) -> str:
return "hi"

schema = strawberry.Schema(Query, mutation=Mutation, extensions=[MyExtension])

result = schema.execute_sync("mutation { myMutation }")
assert not result.errors

assert operation_name is None
assert operation_type == "MUTATION"

# Try again with an operation_name
result = schema.execute_sync("mutation MyMutation { myMutation }")
assert not result.errors

assert operation_name == "MyMutation"
assert operation_type == "MUTATION"

# Try again with an operation_name override
result = schema.execute_sync(
"""
mutation MyMutation { myMutation }
mutation MyMutation2 { myMutation }
""",
operation_name="MyMutation2",
)
assert not result.errors

assert operation_name == "MyMutation2"
assert operation_type == "MUTATION"


def test_execution_context_operation_name_and_type_with_fragmenets():
operation_name = None
operation_type = None

class MyExtension(Extension):
def on_request_end(self):
nonlocal operation_name
nonlocal operation_type

execution_context = self.execution_context

operation_name = execution_context.operation_name
operation_type = execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

result = schema.execute_sync(
"""
fragment MyFragment on Query {
ping
}

query MyOperation {
ping
...MyFragment
}
"""
)
assert not result.errors

assert operation_name == "MyOperation"
assert operation_type == "QUERY"


def test_error_when_accessing_operation_type_before_parsing():
class MyExtension(Extension):
def on_request_start(self):
execution_context = self.execution_context

# This should raise a RuntimeError
execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

with pytest.raises(RuntimeError):
schema.execute_sync("mutation { myMutation }")


def test_error_when_accessing_operation_type_with_invalid_operation_name():
class MyExtension(Extension):
def on_parsing_end(self):
execution_context = self.execution_context

# This should raise a RuntimeError
execution_context.operation_type

schema = strawberry.Schema(Query, extensions=[MyExtension])

with pytest.raises(RuntimeError):
schema.execute_sync("query { ping }", operation_name="MyQuery")