Skip to content

Commit

Permalink
Merge pull request #73 from romank0/payload-extras-upstream
Browse files Browse the repository at this point in the history
Allow to add context to payload
  • Loading branch information
PaulGilmartin authored Feb 15, 2024
2 parents d108979 + 1e1f648 commit 5e8c45f
Show file tree
Hide file tree
Showing 15 changed files with 489 additions and 43 deletions.
6 changes: 6 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Channels
Listeners
---------

.. autoclass:: pgpubsub.ListenerFilterProvider


.. autofunction:: pgpubsub.listener


Expand Down Expand Up @@ -48,6 +51,9 @@ Listeners
.. autofunction:: pgpubsub.trigger_listener


.. autofunction:: pgpubsub.set_notification_context


Notifiers
---------

Expand Down
87 changes: 87 additions & 0 deletions docs/payload_context.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
.. _payload_context:

Payload Context
===============

Sometimes it is beneficial to pass some contextual information from the trigger
to the trigger listener along the payload. Examples are:

- tracing information that allows to track complex request processing in a
multi component system
- in a multitenant system a tenant information to be able to identify the
tenant that peformed an operation that triggered a notification


This can be done by using **Payload Context**. This feature includes:

- ability to add an additional information to the payload in the trigger
- ability to filter by the fields in the context in the listener process
- ability to use ``context`` fields in the listener callbacks


Add ``context`` to payload in the trigger
-----------------------------------------

Before doing updates that produce notifications set the context that should be
passed using ``pgpubsub.set_notification_context`` function.

.. code-block:: python
from pgpubsub import set_notification_context
set_notification_context({'some-key': 'some-value'})
The setting is effective till the connection is closed. Alternatively the
setting ``PGPUBSUB_TX_BOUND_NOTIFICATION_CONTEXT=True`` can be used to clean
the context at the end of the current transanction.


Filter by ``context`` field in the trigger listener
---------------------------------------------------

Note: that the filtering is currently supported only for stored notifications that is
only for channels with ``lock_notifications = True``.

Define a class that implements the ``ListenerFilterProvider`` protocol and set
the option ``PGPUBSUB_LISTENER_FILTER`` to its fully qualified class name.

.. code-block:: python
from pgpubsub import ListenerFilterProvider
class TenantListenerFilterProvider(ListenerFilterProvider):
def get_filter(self) -> Q:
return Q(payload__context__tenant='my-tenant')
# django settings
PGPUBSUB_LISTENER_FILTER = 'myapp.whatever.TenantListenerFilterProvider'
This configuration will skip any notifications that do not have ``tenant`` field
equal to ``my-tenant`` in the payload's ``context`` field.

Pass ``context`` field to the trigger listener callback
-------------------------------------------------------

To enable this set ``PGPUBSUB_CONTEXT_TO_LISTENERS`` to ``True`` in django
settings and add a ``context`` parameter to the listener callback.

.. code-block:: python
# listeners.py
import pgpubsub
from pgpubsub.tests.channels import AuthorTriggerChannel
from pgpubsub.tests.models import Author, Post
@pgpubsub.post_insert_listener(AuthorTriggerChannel)
def create_first_post_for_author(
old: Author, new: Author, context: Dict[str, Any]
):
print(f'Creating first post for {new.name} with context={context}')
Post.objects.create(
author_id=new.pk,
content='Welcome! This is your first post',
date=datetime.date.today(),
)
# django settings
PGPUBSUB_PASS_CONTEXT_TO_LISTENERS = True
4 changes: 2 additions & 2 deletions docs/recovery.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ by supplying it with the ``--recover`` option. This will tell the listening proc
any missed stored notifications automatically when it starts up.


Note that this recovery option can be enabled whenever we use the `listen` management command
by supplying it with the `--recover` option. This will tell the listening processes to replay
Note that this recovery option can be enabled whenever we use the ``listen`` management command
by supplying it with the ``--recover`` option. This will tell the listening processes to replay
any missed stored notifications automatically when it starts up.

It is important to enable server side cursors in the django settings used by
Expand Down
1 change: 1 addition & 0 deletions docs/toc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Table of Contents
notifications
exactly_once_messaging
recovery
payload_context


.. toctree::
Expand Down
3 changes: 2 additions & 1 deletion pgpubsub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pgpubsub.channel import Channel, TriggerChannel
from pgpubsub.channel import Channel, TriggerChannel, set_notification_context
from pgpubsub.listeners import (
listener,
pre_save_listener,
Expand All @@ -10,6 +10,7 @@
pre_delete_listener,
post_delete_listener,
trigger_listener,
ListenerFilterProvider,
)
from pgpubsub.notify import notify, process_stored_notifications

70 changes: 64 additions & 6 deletions pgpubsub/channel.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import datetime
import hashlib
import inspect
import json
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from decimal import Decimal
import datetime
import inspect
import json
from pydoc import locate
from typing import Callable, Dict, Union, List
from typing import Any, Callable, Dict, List, Optional, Union

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.db import connection, connections, models
from django.db.utils import InternalError


registry = defaultdict(list)
Expand Down Expand Up @@ -143,6 +145,20 @@ class TriggerChannel(BaseChannel):
model = NotImplementedError
old: models.Model
new: models.Model
context: Optional[Dict[str, Any]] = None

@classmethod
def pass_context_to_listeners(cls) -> bool:
return getattr(settings, 'PGPUBSUB_PASS_CONTEXT_TO_LISTENERS', False)

@property
def signature(self):
return {
k: v for k, v in self.__dict__.items()
if k in self.__dataclass_fields__ and (
k != 'context' or self.pass_context_to_listeners()
)
}

@classmethod
def deserialize(cls, payload: Union[Dict, str]):
Expand All @@ -167,7 +183,10 @@ def deserialize(cls, payload: Union[Dict, str]):
new = next(new_deserialized_objects, None)
if new is not None:
new = new.object
return {'old': old, 'new': new}
fields = {'old': old, 'new': new}
if cls.pass_context_to_listeners():
fields['context'] = payload_dict.get('context', {})
return fields

@classmethod
def _build_model_serializer_data(cls, payload: Dict, state: str):
Expand Down Expand Up @@ -213,6 +232,45 @@ def _build_model_serializer_data(cls, payload: Dict, state: str):
return model_data


TX_ABORTED_ERROR_MESSAGE = (
'current transaction is aborted, commands ignored until end of transaction block'
)

def set_notification_context(
context: Dict[str, Any], using: Optional[str] = None
) -> None:
if using:
conn = connections[using]
else:
conn = connection
if conn.needs_rollback:
return
use_tx_bound_notification_context = getattr(
settings, 'PGPUBSUB_TX_BOUND_NOTIFICATION_CONTEXT', False
)
if use_tx_bound_notification_context and not conn.in_atomic_block:
raise RuntimeError(
'Transaction bound context can be only set in atomic block. '
'Either start transaction with `atomic` or do not use transaction bound '
'payload context via PGPUBSUB_TX_BOUND_NOTIFICATION_CONTEXT=False'
)
with conn.cursor() as cursor:
try:
if use_tx_bound_notification_context:
scope = 'LOCAL'
else:
scope = 'SESSION'
cursor.execute(
f'SET {scope} pgpubsub.notification_context = %s',
(json.dumps(context),)
)
except InternalError as e:
if TX_ABORTED_ERROR_MESSAGE in str(e):
return
else:
raise


def locate_channel(channel):
if isinstance(channel, str):
channel = locate(channel)
Expand Down
28 changes: 24 additions & 4 deletions pgpubsub/listen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import logging
import multiprocessing
import select
import sys
from typing import List, Optional, Union

from django.conf import settings
from django.core.management import execute_from_command_line
from django.db import connection, transaction
from django.db.models import Func, Value, Q
Expand All @@ -17,6 +19,7 @@
locate_channel,
registry,
)
from pgpubsub.listeners import ListenerFilterProvider
from pgpubsub.models import Notification

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -154,6 +157,18 @@ class CastToJSONB(Func):
template = '((%(expressions)s)::jsonb)'


def get_extra_filter() -> Q:
extra_filter_provider_fq_name = getattr(settings, 'PGPUBSUB_LISTENER_FILTER', None)
if extra_filter_provider_fq_name:
module = importlib.import_module(
'.'.join(extra_filter_provider_fq_name.split('.')[:-1])
)
clazz = getattr(module, extra_filter_provider_fq_name.split('.')[-1])
extra_filter_provider: ListenerFilterProvider = clazz()
return extra_filter_provider.get_filter()
else:
return Q()

class LockableNotificationProcessor(NotificationProcessor):

def validate(self):
Expand All @@ -163,12 +178,16 @@ def validate(self):
def process(self):
logger.info(
f'Processing notification for {self.channel_cls.name()}')
payload_filter = (
Q(payload=CastToJSONB(Value(self.notification.payload))) |
Q(payload=self.notification.payload)
)
payload_filter &= get_extra_filter()
notification = (
Notification.objects.select_for_update(
skip_locked=True).filter(
Q(payload=CastToJSONB(Value(self.notification.payload)))
| Q(payload=self.notification.payload),
channel=self.notification.channel,
payload_filter,
channel=self.notification.channel,
).first()
)
if notification is None:
Expand All @@ -189,9 +208,10 @@ def validate(self):

def process(self):
logger.info(f'Processing all notifications for channel {self.channel_cls.name()} \n')
payload_filter = Q(channel=self.notification.channel) & get_extra_filter()
notifications = (
Notification.objects.select_for_update(
skip_locked=True).filter(channel=self.notification.channel).iterator()
skip_locked=True).filter(payload_filter).iterator()
)
logger.info(f'Found notifications: {notifications}')
for notification in notifications:
Expand Down
8 changes: 7 additions & 1 deletion pgpubsub/listeners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from functools import wraps
from typing import Union, Type
from typing import Protocol, Type, Union

import pgtrigger
from django.db.models import Q
from pgtrigger import Trigger, registered

from pgpubsub.channel import (
Expand Down Expand Up @@ -100,3 +101,8 @@ def wrapper(*args, **kwargs):
return callback(*args, **kwargs)
return wrapper
return _trig_listener


class ListenerFilterProvider(Protocol):
def get_filter(self) -> Q:
...
14 changes: 14 additions & 0 deletions pgpubsub/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from django.db import connection
from pgpubsub.listen import listen_to_channels

@pytest.fixture()
def pg_connection():
return listen_to_channels()


@pytest.fixture
def tx_start_time(django_db_setup):
with connection.cursor() as cursor:
cursor.execute("SELECT now();")
return cursor.fetchone()[0]
5 changes: 5 additions & 0 deletions pgpubsub/tests/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

def simulate_listener_does_not_receive_notifications(pg_connection):
pg_connection.notifies = []
pg_connection.poll()
assert 0 == len(pg_connection.notifies)
16 changes: 12 additions & 4 deletions pgpubsub/tests/listeners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
import datetime
from collections import defaultdict
from typing import Any, Dict, Optional

from django.db.transaction import atomic

Expand Down Expand Up @@ -38,17 +39,24 @@ def notify_post_owner(model_id: int, model_type: str, **kwargs):

@atomic
@pgpubsub.post_insert_listener(AuthorTriggerChannel)
def create_first_post_for_author(old: Author, new: Author):
def create_first_post_for_author(
old: Author, new: Author, context: Optional[Dict[str, Any]] = None
):
print(f'Creating first post for {new.name}')
content = 'Welcome! This is your first post'
if context and 'content' in context:
content = context.get('content')
Post.objects.create(
author_id=new.pk,
content='Welcome! This is your first post',
content=content,
date=datetime.date.today(),
)


@pgpubsub.post_insert_listener(AuthorTriggerChannel)
def another_author_trigger(old: Author, new: Author):
def another_author_trigger(
old: Author, new: Author, context: Optional[Dict[str, Any]] = None
):
print(f'Another author trigger')


Expand Down
Loading

0 comments on commit 5e8c45f

Please sign in to comment.