Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-operation context + initial payload #49

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 55 additions & 37 deletions channels_graphql_ws/graphql_ws_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
import promise
import rx

from .scope_as_context import ScopeAsContext
from .operation_context import OperationContext
from .serializer import Serializer

# Module logger.
Expand Down Expand Up @@ -547,7 +547,7 @@ async def _on_gql_start(self, operation_id, payload):

# Create object-like context (like in `Query` or `Mutation`)
# from the dict-like one provided by the Channels.
context = ScopeAsContext(self.scope)
context = OperationContext(self.scope)

# Adding channel name to the context because it seems to be
# useful for some use cases, take a look at the issue from
Expand Down Expand Up @@ -671,7 +671,12 @@ def register_middleware(next_middleware, root, info, *args, **kwds):
await self._send_gql_complete(operation_id)

async def _register_subscription(
self, operation_id, groups, publish_callback, unsubscribed_callback
self,
operation_id,
groups,
publish_callback,
unsubscribed_callback,
initial_payload,
):
"""Register a new subscription when client subscribes.

Expand Down Expand Up @@ -701,61 +706,74 @@ async def _register_subscription(
# `_sids_by_group` without any locks.
self._assert_thread()

# The subject we will trigger on the `broadcast` message.
trigger = rx.subjects.Subject()

# The subscription notification queue.
notification_queue = asyncio.Queue(
maxsize=self.subscription_notification_queue_limit
)

# Enqueue the initial payload.
if initial_payload is not self.SKIP:
notification_queue.put_nowait(Serializer.serialize(initial_payload))

# Start an endless task which listens the `notification_queue`
# and invokes subscription "resolver" on new notifications.
async def notifier():
async def notifier(observer: rx.Observer):
"""Watch the notification queue and notify clients."""

# Assert we run in a proper thread.
self._assert_thread()
while True:
payload = await notification_queue.get()
serialized_payload = await notification_queue.get()

# Run a subscription's `publish` method (invoked by the
# `trigger.on_next` function) within the threadpool used
# `observer.on_next` function) within the threadpool used
# for processing other GraphQL resolver functions.
# NOTE: `lambda` is important to run the deserialization
# NOTE: it is important to run the deserialization
# in the worker thread as well.
await self._run_in_worker(
lambda: trigger.on_next(Serializer.deserialize(payload))
)
def workload():
try:
payload = Serializer.deserialize(serialized_payload)
except Exception as ex: # pylint: disable=broad-except
observer.on_error(f"Cannot deserialize payload. {ex}")
else:
observer.on_next(payload)

await self._run_in_worker(workload)

# Message processed. This allows `Queue.join` to work.
notification_queue.task_done()

# Enqueue the `publish` method execution. But do not notify
# clients when `publish` returns `SKIP`.
stream = trigger.map(publish_callback).filter( # pylint: disable=no-member
lambda publish_returned: publish_returned is not self.SKIP
)
def push_payloads(observer: rx.Observer):
# Start listening for broadcasts (subscribe to the Channels
# groups), spawn the notification processing task and put
# subscription information into the registry.
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
# atomic i.e. without `awaits` in between.
for group in groups:
self._sids_by_group.setdefault(group, []).append(operation_id)
notifier_task = self._spawn_background_task(notifier(observer))
self._subscriptions[operation_id] = self._SubInf(
groups=groups,
sid=operation_id,
unsubscribed_callback=unsubscribed_callback,
notification_queue=notification_queue,
notifier_task=notifier_task,
)

# Start listening for broadcasts (subscribe to the Channels
# groups), spawn the notification processing task and put
# subscription information into the registry.
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
# atomic i.e. without `awaits` in between.
waitlist = []
for group in groups:
self._sids_by_group.setdefault(group, []).append(operation_id)
waitlist.append(self._channel_layer.group_add(group, self.channel_name))
notifier_task = self._spawn_background_task(notifier())
self._subscriptions[operation_id] = self._SubInf(
groups=groups,
sid=operation_id,
unsubscribed_callback=unsubscribed_callback,
notification_queue=notification_queue,
notifier_task=notifier_task,
await asyncio.wait(
[
self._channel_layer.group_add(group, self.channel_name)
for group in groups
]
)

await asyncio.wait(waitlist)

return stream
# Enqueue the `publish` method execution. But do not notify
# clients when `publish` returns `SKIP`.
return (
rx.Observable.create(push_payloads) # pylint: disable=no-member
.map(publish_callback)
.filter(lambda publish_returned: publish_returned is not self.SKIP)
)

async def _on_gql_stop(self, operation_id):
"""Process the STOP message.
Expand Down
33 changes: 33 additions & 0 deletions channels_graphql_ws/operation_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Just `OperationContext` class."""

from channels_graphql_ws.scope_as_context import ScopeAsContext


class OperationContext(ScopeAsContext):
    """Context object exposed to Graphene resolvers as `info.context`.

    Combines two kinds of state:
    1. `scope` - per-connection context. This is the `scope` of Django
       Channels, shared by all operations on the same connection.
    2. `operation_context` - per-operation context. Starts empty; feel
       free to store your own data here.

    For backward compatibility (behavior provided by `ScopeAsContext`):
    - Method `_asdict` returns the `scope`.
    - Other attributes are routed to the `scope`.
    """

    def __init__(self, scope: dict):
        """Wrap the Channels `scope` and create an empty per-operation dict."""
        super().__init__(scope)
        # Per-operation storage, independent of the connection-wide scope.
        self._operation_context: dict = {}

    @property
    def scope(self) -> dict:
        """Return the per-connection Channels `scope`."""
        return self._scope

    @property
    def operation_context(self) -> dict:
        """Return the per-operation context dict."""
        return self._operation_context
2 changes: 1 addition & 1 deletion channels_graphql_ws/scope_as_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class ScopeAsContext:
"""Wrapper to make Channels `scope` appear as an `info.context`."""

def __init__(self, scope):
def __init__(self, scope: dict):
"""Remember given `scope`."""
self._scope = scope

Expand Down
5 changes: 4 additions & 1 deletion channels_graphql_ws/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def __init_subclass_with_meta__(
_meta.subscribe = get_function(subscribe)
_meta.publish = get_function(publish)
_meta.unsubscribed = get_function(unsubscribed)
_meta.initial_payload = options.get("initial_payload", cls.SKIP)

super().__init_subclass_with_meta__(_meta=_meta, **options)

Expand Down Expand Up @@ -422,7 +423,9 @@ def unsubscribed_callback():
# `subscribe`.
return result

return register_subscription(groups, publish_callback, unsubscribed_callback)
return register_subscription(
groups, publish_callback, unsubscribed_callback, cls._meta.initial_payload
)

@classmethod
def _group_name(cls, group=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,9 +732,9 @@ async def test_message_order_in_subscribe_unsubscribe_all_loop(
'complete' message.
"""

NUMBER_OF_UNSUBSCRIBE_CALLS = 50 # pylint: disable=invalid-name
NUMBER_OF_UNSUBSCRIBE_CALLS = 100 # pylint: disable=invalid-name
# Delay in seconds.
DELAY_BETWEEN_UNSUBSCRIBE_CALLS = 0.01 # pylint: disable=invalid-name
DELAY_BETWEEN_UNSUBSCRIBE_CALLS = 0.02 # pylint: disable=invalid-name
# Gradually stop the test if time is up.
TIME_BORDER = 20 # pylint: disable=invalid-name

Expand Down