Skip to content

Commit

Permalink
Fix consistent return response from PubSubPullSensor (#42080)
Browse files Browse the repository at this point in the history
* fix consistent return response pubsubsensor

* removed messages_callback argument to pubsub trigger and using it in execute_complete

* updated variable name

* updates as per comments, added return types and refactored logic

* update types, tests and use inherit exception
  • Loading branch information
gopidesupavan authored Oct 2, 2024
1 parent 00589cf commit 64e972c
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 20 deletions.
24 changes: 21 additions & 3 deletions airflow/providers/google/cloud/sensors/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence

from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.configuration import conf
Expand All @@ -34,6 +35,10 @@
from airflow.utils.context import Context


class PubSubMessageTransformException(AirflowException):
"""Raise when messages failed to convert pubsub received format."""


class PubSubPullSensor(BaseSensorOperator):
"""
Pulls messages from a PubSub subscription and passes them through XCom.
Expand Down Expand Up @@ -170,22 +175,35 @@ def execute(self, context: Context) -> None:
subscription=self.subscription,
max_messages=self.max_messages,
ack_messages=self.ack_messages,
messages_callback=self.messages_callback,
poke_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]:
"""Return immediately and relies on trigger to throw a success event. Callback for the trigger."""
def execute_complete(self, context: Context, event: dict[str, str | list[str]]) -> Any:
"""If messages_callback is provided, execute it; otherwise, return immediately with trigger event message."""
if event["status"] == "success":
self.log.info("Sensor pulls messages: %s", event["message"])
if self.messages_callback:
received_messages = self._convert_to_received_messages(event["message"])
_return_value = self.messages_callback(received_messages, context)
return _return_value

return event["message"]
self.log.info("Sensor failed: %s", event["message"])
raise AirflowException(event["message"])

def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
try:
received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages]
return received_messages
except Exception as e:
raise PubSubMessageTransformException(
f"Error converting triggerer event message back to received message format: {e}"
)

def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
Expand Down
22 changes: 7 additions & 15 deletions airflow/providers/google/cloud/triggers/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,13 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Sequence
from typing import Any, AsyncIterator, Sequence

from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.utils.context import Context


class PubsubPullTrigger(BaseTrigger):
"""
Expand All @@ -41,11 +38,6 @@ class PubsubPullTrigger(BaseTrigger):
:param ack_messages: If True, each message will be acknowledged
immediately rather than by any downstream tasks
:param gcp_conn_id: Reference to google cloud connection id
:param messages_callback: (Optional) Callback to process received messages.
Its return value will be saved to XCom.
If you are pulling large messages, you probably want to provide a custom callback.
If not provided, the default implementation will convert `ReceivedMessage` objects
into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function.
:param poke_interval: polling period in seconds to check for the status
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
Expand All @@ -64,7 +56,6 @@ def __init__(
max_messages: int,
ack_messages: bool,
gcp_conn_id: str,
messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
poke_interval: float = 10.0,
impersonation_chain: str | Sequence[str] | None = None,
):
Expand All @@ -73,7 +64,6 @@ def __init__(
self.subscription = subscription
self.max_messages = max_messages
self.ack_messages = ack_messages
self.messages_callback = messages_callback
self.poke_interval = poke_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
Expand All @@ -88,7 +78,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"subscription": self.subscription,
"max_messages": self.max_messages,
"ack_messages": self.ack_messages,
"messages_callback": self.messages_callback,
"poke_interval": self.poke_interval,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
Expand All @@ -106,7 +95,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
):
if self.ack_messages:
await self.message_acknowledgement(pulled_messages)
yield TriggerEvent({"status": "success", "message": pulled_messages})

messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]

yield TriggerEvent({"status": "success", "message": messages_json})
return
self.log.info("Sleeping for %s seconds.", self.poke_interval)
await asyncio.sleep(self.poke_interval)
Expand Down
48 changes: 48 additions & 0 deletions tests/providers/google/cloud/sensors/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import pytest
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.exceptions import AirflowException, TaskDeferred
Expand Down Expand Up @@ -197,3 +198,50 @@ def test_pubsub_pull_sensor_async_execute_complete(self):
with mock.patch.object(operator.log, "info") as mock_log_info:
operator.execute_complete(context={}, event={"status": "success", "message": test_message})
mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)

@mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook")
def test_pubsub_pull_sensor_async_execute_complete_use_message_callback(self, mock_hook):
test_message = [
{
"ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
"message": {
"data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
"message_id": "12165864188103151",
"publish_time": "2024-08-28T11:49:50.962Z",
"attributes": {},
"ordering_key": "",
},
"delivery_attempt": 0,
}
]

received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]

messages_callback_return_value = "custom_message_from_callback"

def messages_callback(
pulled_messages: list[ReceivedMessage],
context: dict[str, Any],
):
assert pulled_messages == received_messages

assert isinstance(context, dict)
for key in context.keys():
assert isinstance(key, str)

return messages_callback_return_value

operator = PubSubPullSensor(
task_id="test_task",
ack_messages=True,
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
deferrable=True,
messages_callback=messages_callback,
)
mock_hook.return_value.pull.return_value = received_messages

with mock.patch.object(operator.log, "info") as mock_log_info:
resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
assert resp == messages_callback_return_value
55 changes: 53 additions & 2 deletions tests/providers/google/cloud/triggers/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.triggers.base import TriggerEvent

TEST_POLL_INTERVAL = 10
TEST_GCP_CONN_ID = "google_cloud_default"
Expand All @@ -34,13 +38,25 @@ def trigger():
subscription="subscription",
max_messages=MAX_MESSAGES,
ack_messages=ACK_MESSAGES,
messages_callback=None,
poke_interval=TEST_POLL_INTERVAL,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
)


async def generate_messages(count: int) -> list[ReceivedMessage]:
return [
ReceivedMessage(
ack_id=f"{i}",
message={
"data": f"Message {i}".encode(),
"attributes": {"type": "generated message"},
},
)
for i in range(1, count + 1)
]


class TestPubsubPullTrigger:
def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(self, trigger):
"""
Expand All @@ -54,8 +70,43 @@ def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(sel
"subscription": "subscription",
"max_messages": MAX_MESSAGES,
"ack_messages": ACK_MESSAGES,
"messages_callback": None,
"poke_interval": TEST_POLL_INTERVAL,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
}

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubAsyncHook.pull")
async def test_async_pubsub_pull_trigger_return_event(self, mock_pull):
mock_pull.return_value = generate_messages(1)
trigger = PubsubPullTrigger(
project_id=PROJECT_ID,
subscription="subscription",
max_messages=MAX_MESSAGES,
ack_messages=False,
poke_interval=TEST_POLL_INTERVAL,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
)

expected_event = TriggerEvent(
{
"status": "success",
"message": [
{
"ack_id": "1",
"message": {
"data": "TWVzc2FnZSAx",
"attributes": {"type": "generated message"},
"message_id": "",
"ordering_key": "",
},
"delivery_attempt": 0,
}
],
}
)

response = await trigger.run().asend(None)

assert response == expected_event

0 comments on commit 64e972c

Please sign in to comment.