Skip to content

Commit

Permalink
adds support for psycopg version 3
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Konoval committed Aug 10, 2024
1 parent c18a744 commit 0ebebd5
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 169 deletions.
39 changes: 39 additions & 0 deletions pgpubsub/compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

try:
from psycopg2._psycopg import Notify

class ConnectionWrapper:
def __init__(self, conn):
self.connection = conn

def poll(self):
self.connection.poll()

@property
def notifies(self):
return self.connection.notifies

@notifies.setter
def notifies(self, value: Notify) -> None:
self.connection.notifies = value

def stop(self):
pass

except ImportError:
from psycopg import Notify

class ConnectionWrapper:
def __init__(self, conn):
self.connection = conn
self.notifies = []
self.connection.add_notify_handler(self._notify_handler)

def _notify_handler(self, notification):
self.notifies.append(notification)

def poll(self):
self.connection.execute("SELECT 1")

def stop(self):
self.connection.remove_notify_handler(self._notify_handler)
60 changes: 32 additions & 28 deletions pgpubsub/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from django.core.management import execute_from_command_line
from django.db import connection, transaction
from django.db.models import Func, Value, Q
from psycopg2._psycopg import Notify

from pgpubsub import process_stored_notifications
from pgpubsub.channel import (
Expand All @@ -19,6 +18,7 @@
locate_channel,
registry,
)
from pgpubsub.compatibility import ConnectionWrapper, Notify
from pgpubsub.listeners import ListenerFilterProvider
from pgpubsub.models import Notification

Expand Down Expand Up @@ -65,32 +65,36 @@ def start_listen_in_a_process(
return process



def listen(
channels: Union[List[BaseChannel], List[str]] = None,
recover: bool = False,
autorestart_on_failure: bool = True,
start_method: str = 'spawn',
):
pg_connection = listen_to_channels(channels)
connection_wrapper = listen_to_channels(channels)

if recover:
process_stored_notifications(channels)
process_notifications(pg_connection)
try:
if recover:
process_stored_notifications(channels)
process_notifications(connection_wrapper)

logger.info('Listening for notifications... \n')
while POLL:
if select.select([pg_connection], [], [], 1) == ([], [], []):
pass
else:
try:
process_notifications(pg_connection)
except Exception as e:
logger.error(f'Encountered exception {e}', exc_info=e)
if autorestart_on_failure:
start_listen_in_a_process(
channels, recover, autorestart_on_failure, start_method
)
raise
logger.info('Listening for notifications... \n')
while POLL:
if select.select([connection_wrapper.connection], [], [], 1) == ([], [], []):
pass
else:
try:
process_notifications(connection_wrapper)
except Exception as e:
logger.error(f'Encountered exception {e}', exc_info=e)
if autorestart_on_failure:
start_listen_in_a_process(
channels, recover, autorestart_on_failure, start_method
)
raise
finally:
connection_wrapper.stop()


def listen_to_channels(channels: Union[List[BaseChannel], List[str]] = None):
Expand All @@ -109,21 +113,21 @@ def listen_to_channels(channels: Union[List[BaseChannel], List[str]] = None):
for channel in channels:
logger.info(f'Listening on {channel.name()}\n')
cursor.execute(f'LISTEN {channel.listen_safe_name()};')
return connection.connection
return ConnectionWrapper(connection.connection)


def process_notifications(pg_connection):
pg_connection.poll()
while pg_connection.notifies:
notification = pg_connection.notifies.pop(0)
def process_notifications(connection_wrapper):
connection_wrapper.poll()
while connection_wrapper.notifies:
notification = connection_wrapper.notifies.pop(0)
with transaction.atomic():
for processor in [
NotificationProcessor,
LockableNotificationProcessor,
NotificationRecoveryProcessor,
]:
try:
processor = processor(notification, pg_connection)
processor = processor(notification, connection_wrapper)
except InvalidNotificationProcessor:
continue
else:
Expand All @@ -132,10 +136,10 @@ def process_notifications(pg_connection):


class NotificationProcessor:
def __init__(self, notification: Notify, pg_connection):
def __init__(self, notification: Notify, connection_wrapper):
self.notification = notification
self.channel_cls, self.callbacks = Channel.get(notification.channel)
self.pg_connection = pg_connection
self.connection_wrapper = connection_wrapper
self.validate()

def validate(self):
Expand All @@ -150,7 +154,7 @@ def _execute(self):
channel = self.channel_cls.build_from_payload(
self.notification.payload, self.callbacks)
channel.execute_callbacks()
self.pg_connection.poll()
self.connection_wrapper.poll()


class CastToJSONB(Func):
Expand Down
Loading

0 comments on commit 0ebebd5

Please sign in to comment.