From b41cf629c8624b906ed29760e14037e8d2f9a370 Mon Sep 17 00:00:00 2001 From: Eugene <53026723+e-galan@users.noreply.github.com> Date: Thu, 18 Apr 2024 06:38:11 +0000 Subject: [PATCH] Deferrable mode for Dataflow sensors (#37693) --- .../providers/google/cloud/hooks/dataflow.py | 103 ++- .../google/cloud/sensors/dataflow.py | 282 ++++++-- .../google/cloud/triggers/dataflow.py | 506 +++++++++++++- .../operators/cloud/dataflow.rst | 32 + .../google/cloud/hooks/test_dataflow.py | 55 +- .../google/cloud/sensors/test_dataflow.py | 462 ++++++++++++- .../google/cloud/triggers/test_dataflow.py | 619 +++++++++++++++++- .../example_dataflow_sensors_deferrable.py | 190 ++++++ 8 files changed, 2176 insertions(+), 73 deletions(-) create mode 100644 tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 64f309709c74..a9bf802b14be 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -31,7 +31,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast from deprecated import deprecated -from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView +from google.cloud.dataflow_v1beta3 import ( + GetJobRequest, + Job, + JobState, + JobsV1Beta3AsyncClient, + JobView, + ListJobMessagesRequest, + MessagesV1Beta3AsyncClient, + MetricsV1Beta3AsyncClient, +) +from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest from googleapiclient.discovery import build @@ -47,6 +57,8 @@ if TYPE_CHECKING: from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager + from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager + from google.protobuf.timestamp_pb2 import Timestamp # This is the default location @@ -1353,3 +1365,92 @@ async def list_jobs( ) page_result: ListJobsAsyncPager = await client.list_jobs(request=request) return page_result + + async def list_job_messages( + self, + job_id: str, + project_id: str | None = PROVIDE_PROJECT_ID, + minimum_importance: int = JobMessageImportance.JOB_MESSAGE_BASIC, + page_size: int | None = None, + page_token: str | None = None, + start_time: Timestamp | None = None, + end_time: Timestamp | None = None, + location: str | None = DEFAULT_DATAFLOW_LOCATION, + ) -> ListJobMessagesAsyncPager: + """ + Return ListJobMessagesAsyncPager object from MessagesV1Beta3AsyncClient. + + This method wraps around a similar method of MessagesV1Beta3AsyncClient. ListJobMessagesAsyncPager can be iterated + over to extract messages associated with a specific Job ID. + + For more details see the MessagesV1Beta3AsyncClient method description at: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.messages_v1_beta3.MessagesV1Beta3AsyncClient + + :param job_id: ID of the Dataflow job to get messages about. + :param project_id: Optional. The Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param minimum_importance: Optional. Filter to only get messages with importance >= level. 
+ For more details see the description at: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.JobMessageImportance + :param page_size: Optional. If specified, determines the maximum number of messages to return. + If unspecified, the service may choose an appropriate default, or may return an arbitrarily large number of results. + :param page_token: Optional. If supplied, this should be the value of next_page_token returned by an earlier call. + This will cause the next page of results to be returned. + :param start_time: Optional. If specified, return only messages with timestamps >= start_time. + The default is the job creation time (i.e. beginning of messages). + :param end_time: Optional. If specified, return only messages with timestamps < end_time. The default is the current time. + :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains + the job specified by job_id. + """ + project_id = project_id or (await self.get_project_id()) + client = await self.initialize_client(MessagesV1Beta3AsyncClient) + request = ListJobMessagesRequest( + { + "project_id": project_id, + "job_id": job_id, + "minimum_importance": minimum_importance, + "page_size": page_size, + "page_token": page_token, + "start_time": start_time, + "end_time": end_time, + "location": location, + } + ) + page_results: ListJobMessagesAsyncPager = await client.list_job_messages(request=request) + return page_results + + async def get_job_metrics( + self, + job_id: str, + project_id: str | None = PROVIDE_PROJECT_ID, + start_time: Timestamp | None = None, + location: str | None = DEFAULT_DATAFLOW_LOCATION, + ) -> JobMetrics: + """ + Return JobMetrics object from MetricsV1Beta3AsyncClient. + + This method wraps around a similar method of MetricsV1Beta3AsyncClient. + + For more details see the MetricsV1Beta3AsyncClient method description at: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.metrics_v1_beta3.MetricsV1Beta3AsyncClient + + :param job_id: ID of the Dataflow job to get metrics for. + :param project_id: Optional. The Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param start_time: Optional. Return only metric data that has changed since this time. + Default is to return all information about all metrics for the job. + :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains + the job specified by job_id. 
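+
+        A minimal usage sketch (the hook instance, project ID and job ID below are
+        placeholders, not part of this change)::
+
+            hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
+            job_metrics = await hook.get_job_metrics(
+                job_id="<dataflow-job-id>",
+                project_id="my-project",
+                location="us-central1",
+            )
+            for metric_update in job_metrics.metrics:
+                print(metric_update.name)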
+ """ + project_id = project_id or (await self.get_project_id()) + client: MetricsV1Beta3AsyncClient = await self.initialize_client(MetricsV1Beta3AsyncClient) + request = GetJobMetricsRequest( + { + "project_id": project_id, + "job_id": job_id, + "start_time": start_time, + "location": location, + } + ) + job_metrics: JobMetrics = await client.get_job_metrics(request=request) + return job_metrics diff --git a/airflow/providers/google/cloud/sensors/dataflow.py b/airflow/providers/google/cloud/sensors/dataflow.py index 10997a35f134..b397d56e2ebf 100644 --- a/airflow/providers/google/cloud/sensors/dataflow.py +++ b/airflow/providers/google/cloud/sensors/dataflow.py @@ -19,14 +19,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Callable, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataflow import ( DEFAULT_DATAFLOW_LOCATION, DataflowHook, DataflowJobStatus, ) +from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowJobAutoScalingEventTrigger, + DataflowJobMessagesTrigger, + DataflowJobMetricsTrigger, + DataflowJobStatusTrigger, +) from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -42,7 +50,7 @@ class DataflowJobStatusSensor(BaseSensorOperator): :ref:`howto/operator:DataflowJobStatusSensor` :param job_id: ID of the job to be checked. - :param expected_statuses: The expected state of the operation. + :param expected_statuses: The expected state(s) of the operation. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState :param project_id: Optional, the Google Cloud project ID in which to start a job. @@ -58,6 +66,8 @@ class DataflowJobStatusSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the sensor in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. 
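+
+    A minimal usage sketch (the task id, job id and project id below are placeholders,
+    not part of this change)::
+
+        wait_for_job_done = DataflowJobStatusSensor(
+            task_id="wait_for_dataflow_job_done",
+            job_id="<dataflow-job-id>",
+            expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
+            project_id="my-project",
+            location="us-central1",
+            deferrable=True,
+            poll_interval=60,
+        )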
""" template_fields: Sequence[str] = ("job_id",) @@ -71,6 +81,8 @@ def __init__( location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 10, **kwargs, ) -> None: super().__init__(**kwargs) @@ -82,7 +94,8 @@ def __init__( self.location = location self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook: DataflowHook | None = None + self.deferrable = deferrable + self.poll_interval = poll_interval def poke(self, context: Context) -> bool: self.log.info( @@ -90,10 +103,6 @@ def poke(self, context: Context) -> bool: self.job_id, ", ".join(self.expected_statuses), ) - self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) job = self.hook.get_job( job_id=self.job_id, @@ -115,10 +124,51 @@ def poke(self, context: Context) -> bool: return False + def execute(self, context: Context) -> None: + """Airflow runs this method on the worker and defers using the trigger.""" + if not self.deferrable: + super().execute(context) + elif not self.poke(context=context): + self.defer( + timeout=self.execution_timeout, + trigger=DataflowJobStatusTrigger( + job_id=self.job_id, + expected_statuses=self.expected_statuses, + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str | list]) -> bool: + """ + Execute this method when the task resumes its execution on the worker after deferral. + + Returns True if the trigger returns an event with the success status, otherwise raises + an exception. + """ + if event["status"] == "success": + self.log.info(event["message"]) + return True + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") + raise AirflowException(f"Sensor failed with the following message: {event['message']}") + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class DataflowJobMetricsSensor(BaseSensorOperator): """ - Checks the metrics of a job in Google Cloud Dataflow. + Checks for metrics associated with a single job in Google Cloud Dataflow. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -143,6 +193,9 @@ class DataflowJobMetricsSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the sensor in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. 
+ """ template_fields: Sequence[str] = ("job_id",) @@ -151,12 +204,14 @@ def __init__( self, *, job_id: str, - callback: Callable[[dict], bool], + callback: Callable | None = None, fail_on_terminal_state: bool = True, project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 10, **kwargs, ) -> None: super().__init__(**kwargs) @@ -167,14 +222,10 @@ def __init__( self.location = location self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook: DataflowHook | None = None + self.deferrable = deferrable + self.poll_interval = poll_interval def poke(self, context: Context) -> bool: - self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - if self.fail_on_terminal_state: job = self.hook.get_job( job_id=self.job_id, @@ -194,27 +245,73 @@ def poke(self, context: Context) -> bool: project_id=self.project_id, location=self.location, ) + return result["metrics"] if self.callback is None else self.callback(result["metrics"]) + + def execute(self, context: Context) -> Any: + """Airflow runs this method on the worker and defers using the trigger.""" + if not self.deferrable: + super().execute(context) + else: + self.defer( + timeout=self.execution_timeout, + trigger=DataflowJobMetricsTrigger( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_interval, + impersonation_chain=self.impersonation_chain, + fail_on_terminal_state=self.fail_on_terminal_state, + ), + method_name="execute_complete", + ) - return self.callback(result["metrics"]) + def execute_complete(self, context: Context, event: dict[str, str | list]) -> Any: + """ + Execute this method when the task resumes its execution on the worker after deferral. + + If the trigger returns an event with success status - passes the event result to the callback function. + Returns the event result if no callback function is provided. + + If the trigger returns an event with error status - raises an exception. + """ + if event["status"] == "success": + self.log.info(event["message"]) + return event["result"] if self.callback is None else self.callback(event["result"]) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") + raise AirflowException(f"Sensor failed with the following message: {event['message']}") + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) class DataflowJobMessagesSensor(BaseSensorOperator): """ - Checks for the job message in Google Cloud Dataflow. + Checks for job messages associated with a single job in Google Cloud Dataflow. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataflowJobMessagesSensor` - :param job_id: ID of the job to be checked. 
- :param callback: callback which is called with list of read job metrics - See: - https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate - :param fail_on_terminal_state: If set to true sensor will raise Exception when - job is in terminal state + :param job_id: ID of the Dataflow job to be checked. + :param callback: a function that can accept a list of serialized job messages. + It can do whatever you want it to do. If the callback function is not provided, + then on successful completion the task will exit with True value. + For more info about the job message content see: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.JobMessage + :param fail_on_terminal_state: If set to True the sensor will raise an exception when the job reaches a terminal state. + No job messages will be returned. :param project_id: Optional, the Google Cloud project ID in which to start a job. If set to None or missing, the default project_id from the Google Cloud connection is used. - :param location: Job location. + :param location: The location of the Dataflow job (for example europe-west1). + If set to None then the value of DEFAULT_DATAFLOW_LOCATION will be used. + See: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -224,6 +321,8 @@ class DataflowJobMessagesSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the sensor in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. 
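+
+    A minimal usage sketch (the task id, job id and project id are placeholders; with no
+    callback the task completes as soon as any job messages are returned)::
+
+        wait_for_job_message = DataflowJobMessagesSensor(
+            task_id="wait_for_dataflow_job_message",
+            job_id="<dataflow-job-id>",
+            callback=None,
+            fail_on_terminal_state=False,
+            project_id="my-project",
+            location="us-central1",
+            deferrable=True,
+            poll_interval=60,
+        )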
""" template_fields: Sequence[str] = ("job_id",) @@ -232,12 +331,14 @@ def __init__( self, *, job_id: str, - callback: Callable, + callback: Callable | None = None, fail_on_terminal_state: bool = True, project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 10, **kwargs, ) -> None: super().__init__(**kwargs) @@ -248,14 +349,10 @@ def __init__( self.location = location self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook: DataflowHook | None = None + self.deferrable = deferrable + self.poll_interval = poll_interval def poke(self, context: Context) -> bool: - self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - if self.fail_on_terminal_state: job = self.hook.get_job( job_id=self.job_id, @@ -276,26 +373,73 @@ def poke(self, context: Context) -> bool: location=self.location, ) - return self.callback(result) + return result if self.callback is None else self.callback(result) + + def execute(self, context: Context) -> Any: + """Airflow runs this method on the worker and defers using the trigger.""" + if not self.deferrable: + super().execute(context) + else: + self.defer( + timeout=self.execution_timeout, + trigger=DataflowJobMessagesTrigger( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_interval, + impersonation_chain=self.impersonation_chain, + fail_on_terminal_state=self.fail_on_terminal_state, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str | list]) -> Any: + """ + Execute this method when the task resumes its execution on the worker after deferral. + + If the trigger returns an event with success status - passes the event result to the callback function. + Returns the event result if no callback function is provided. + + If the trigger returns an event with error status - raises an exception. + """ + if event["status"] == "success": + self.log.info(event["message"]) + return event["result"] if self.callback is None else self.callback(event["result"]) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") + raise AirflowException(f"Sensor failed with the following message: {event['message']}") + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) class DataflowJobAutoScalingEventsSensor(BaseSensorOperator): """ - Checks for the job autoscaling event in Google Cloud Dataflow. + Checks for autoscaling events associated with a single job in Google Cloud Dataflow. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataflowJobAutoScalingEventsSensor` - :param job_id: ID of the job to be checked. - :param callback: callback which is called with list of read job metrics - See: - https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate - :param fail_on_terminal_state: If set to true sensor will raise Exception when - job is in terminal state + :param job_id: ID of the Dataflow job to be checked. 
+ :param callback: a function that can accept a list of serialized autoscaling events. + It can do whatever you want it to do. If the callback function is not provided, + then on successful completion the task will exit with True value. + For more info about the autoscaling event content see: + https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.AutoscalingEvent + :param fail_on_terminal_state: If set to True the sensor will raise an exception when the job reaches a terminal state. + No autoscaling events will be returned. :param project_id: Optional, the Google Cloud project ID in which to start a job. If set to None or missing, the default project_id from the Google Cloud connection is used. - :param location: Job location. + :param location: The location of the Dataflow job (for example europe-west1). + If set to None then the value of DEFAULT_DATAFLOW_LOCATION will be used. + See: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -305,6 +449,8 @@ class DataflowJobAutoScalingEventsSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the sensor in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. """ template_fields: Sequence[str] = ("job_id",) @@ -313,12 +459,14 @@ def __init__( self, *, job_id: str, - callback: Callable, + callback: Callable | None = None, fail_on_terminal_state: bool = True, project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 60, **kwargs, ) -> None: super().__init__(**kwargs) @@ -329,14 +477,10 @@ def __init__( self.location = location self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook: DataflowHook | None = None + self.deferrable = deferrable + self.poll_interval = poll_interval def poke(self, context: Context) -> bool: - self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - if self.fail_on_terminal_state: job = self.hook.get_job( job_id=self.job_id, @@ -357,4 +501,46 @@ def poke(self, context: Context) -> bool: location=self.location, ) - return self.callback(result) + return result if self.callback is None else self.callback(result) + + def execute(self, context: Context) -> Any: + """Airflow runs this method on the worker and defers using the trigger.""" + if not self.deferrable: + super().execute(context) + else: + self.defer( + trigger=DataflowJobAutoScalingEventTrigger( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_interval, + impersonation_chain=self.impersonation_chain, + fail_on_terminal_state=self.fail_on_terminal_state, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str | list]) -> Any: + """ + 
Execute this method when the task resumes its execution on the worker after deferral. + + If the trigger returns an event with success status - passes the event result to the callback function. + Returns the event result if no callback function is provided. + + If the trigger returns an event with error status - raises an exception. + """ + if event["status"] == "success": + self.log.info(event["message"]) + return event["result"] if self.callback is None else self.callback(event["result"]) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") + raise AirflowException(f"Sensor failed with the following message: {event['message']}") + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py index c752b79978dd..32f68a9fd703 100644 --- a/airflow/providers/google/cloud/triggers/dataflow.py +++ b/airflow/providers/google/cloud/triggers/dataflow.py @@ -18,13 +18,24 @@ from __future__ import annotations import asyncio -from typing import Any, Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Sequence from google.cloud.dataflow_v1beta3 import JobState +from google.cloud.dataflow_v1beta3.types import ( + AutoscalingEvent, + JobMessage, + JobMetrics, + MetricUpdate, +) -from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook +from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook, DataflowJobStatus from airflow.triggers.base import BaseTrigger, TriggerEvent +if TYPE_CHECKING: + from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager + + DEFAULT_DATAFLOW_LOCATION = "us-central1" @@ -59,7 +70,6 @@ def __init__( cancel_timeout: int | None = 5 * 60, ): super().__init__() - self.project_id = project_id self.job_id = job_id self.location = location @@ -142,3 +152,493 @@ def _get_async_hook(self) -> AsyncDataflowHook: impersonation_chain=self.impersonation_chain, cancel_timeout=self.cancel_timeout, ) + + +class DataflowJobStatusTrigger(BaseTrigger): + """ + Trigger that checks for metrics associated with a Dataflow job. + + :param job_id: Required. ID of the job. + :param expected_statuses: The expected state(s) of the operation. + See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: Optional. The location where the job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud. + :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + def __init__( + self, + job_id: str, + expected_statuses: set[str], + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.job_id = job_id + self.expected_statuses = expected_statuses + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStatusTrigger", + { + "job_id": self.job_id, + "expected_statuses": self.expected_statuses, + "project_id": self.project_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "impersonation_chain": self.impersonation_chain, + }, + ) + + async def run(self): + """ + Loop until the job reaches an expected or terminal state. + + Yields a TriggerEvent with success status, if the client returns an expected job status. + + Yields a TriggerEvent with error status, if the client returns an unexpected terminal + job status or any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.poll_sleep variable. + """ + try: + while True: + job_status = await self.async_hook.get_job_status( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + if job_status.name in self.expected_statuses: + yield TriggerEvent( + { + "status": "success", + "message": f"Job with id '{self.job_id}' has reached an expected state: {job_status.name}", + } + ) + return + elif job_status.name in DataflowJobStatus.TERMINAL_STATES: + yield TriggerEvent( + { + "status": "error", + "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}", + } + ) + return + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.error("Exception occurred while checking for job status!") + yield TriggerEvent( + { + "status": "error", + "message": str(e), + } + ) + + @cached_property + def async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + ) + + +class DataflowJobMetricsTrigger(BaseTrigger): + """ + Trigger that checks for metrics associated with a Dataflow job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: Optional. The location where the job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud. + :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job. + :param impersonation_chain: Optional. 
Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with + error status if the job reaches a terminal state. + """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + fail_on_terminal_state: bool = True, + ): + super().__init__() + self.project_id = project_id + self.job_id = job_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + self.fail_on_terminal_state = fail_on_terminal_state + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger", + { + "project_id": self.project_id, + "job_id": self.job_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "impersonation_chain": self.impersonation_chain, + "fail_on_terminal_state": self.fail_on_terminal_state, + }, + ) + + async def run(self): + """ + Loop until a terminal job status or any job metrics are returned. + + Yields a TriggerEvent with success status, if the client returns any job metrics + and fail_on_terminal_state attribute is False. + + Yields a TriggerEvent with error status, if the client returns a job status with + a terminal state value and fail_on_terminal_state attribute is True. + + Yields a TriggerEvent with error status, if any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.poll_sleep variable. 
+ """ + try: + while True: + job_status = await self.async_hook.get_job_status( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_metrics = await self.get_job_metrics() + if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES: + yield TriggerEvent( + { + "status": "error", + "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}", + "result": None, + } + ) + return + if job_metrics: + yield TriggerEvent( + { + "status": "success", + "message": f"Detected {len(job_metrics)} metrics for job '{self.job_id}'", + "result": job_metrics, + } + ) + return + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.error("Exception occurred while checking for job's metrics!") + yield TriggerEvent({"status": "error", "message": str(e), "result": None}) + + async def get_job_metrics(self) -> list[dict[str, Any]]: + """Wait for the Dataflow client response and then return it in a serialized list.""" + job_response: JobMetrics = await self.async_hook.get_job_metrics( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + return self._get_metrics_from_job_response(job_response) + + def _get_metrics_from_job_response(self, job_response: JobMetrics) -> list[dict[str, Any]]: + """Return a list of serialized MetricUpdate objects.""" + return [MetricUpdate.to_dict(metric) for metric in job_response.metrics] + + @cached_property + def async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + ) + + +class DataflowJobAutoScalingEventTrigger(BaseTrigger): + """ + Trigger that checks for autoscaling events associated with a Dataflow job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: Optional. The location where the job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud. + :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with + error status if the job reaches a terminal state. 
+ """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + fail_on_terminal_state: bool = True, + ): + super().__init__() + self.project_id = project_id + self.job_id = job_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + self.fail_on_terminal_state = fail_on_terminal_state + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger", + { + "project_id": self.project_id, + "job_id": self.job_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "impersonation_chain": self.impersonation_chain, + "fail_on_terminal_state": self.fail_on_terminal_state, + }, + ) + + async def run(self): + """ + Loop until a terminal job status or any autoscaling events are returned. + + Yields a TriggerEvent with success status, if the client returns any autoscaling events + and fail_on_terminal_state attribute is False. + + Yields a TriggerEvent with error status, if the client returns a job status with + a terminal state value and fail_on_terminal_state attribute is True. + + Yields a TriggerEvent with error status, if any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.poll_sleep variable. + """ + try: + while True: + job_status = await self.async_hook.get_job_status( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + autoscaling_events = await self.list_job_autoscaling_events() + if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES: + yield TriggerEvent( + { + "status": "error", + "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}", + "result": None, + } + ) + return + if autoscaling_events: + yield TriggerEvent( + { + "status": "success", + "message": f"Detected {len(autoscaling_events)} autoscaling events for job '{self.job_id}'", + "result": autoscaling_events, + } + ) + return + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.error("Exception occurred while checking for job's autoscaling events!") + yield TriggerEvent({"status": "error", "message": str(e), "result": None}) + + async def list_job_autoscaling_events(self) -> list[dict[str, str | dict]]: + """Wait for the Dataflow client response and then return it in a serialized list.""" + job_response: ListJobMessagesAsyncPager = await self.async_hook.list_job_messages( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + return self._get_autoscaling_events_from_job_response(job_response) + + def _get_autoscaling_events_from_job_response( + self, job_response: ListJobMessagesAsyncPager + ) -> list[dict[str, str | dict]]: + """Return a list of serialized AutoscalingEvent objects.""" + return [AutoscalingEvent.to_dict(event) for event in job_response.autoscaling_events] + + @cached_property + def async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + ) + + +class 
DataflowJobMessagesTrigger(BaseTrigger): + """ + Trigger that checks for job messages associated with a Dataflow job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: Optional. The location where the job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud. + :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with + error status if the job reaches a terminal state. + """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + fail_on_terminal_state: bool = True, + ): + super().__init__() + self.project_id = project_id + self.job_id = job_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + self.fail_on_terminal_state = fail_on_terminal_state + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger", + { + "project_id": self.project_id, + "job_id": self.job_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "impersonation_chain": self.impersonation_chain, + "fail_on_terminal_state": self.fail_on_terminal_state, + }, + ) + + async def run(self): + """ + Loop until a terminal job status or any job messages are returned. + + Yields a TriggerEvent with success status, if the client returns any job messages + and fail_on_terminal_state attribute is False. + + Yields a TriggerEvent with error status, if the client returns a job status with + a terminal state value and fail_on_terminal_state attribute is True. + + Yields a TriggerEvent with error status, if any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.poll_sleep variable. 
+ """ + try: + while True: + job_status = await self.async_hook.get_job_status( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_messages = await self.list_job_messages() + if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES: + yield TriggerEvent( + { + "status": "error", + "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}", + "result": None, + } + ) + return + if job_messages: + yield TriggerEvent( + { + "status": "success", + "message": f"Detected {len(job_messages)} job messages for job '{self.job_id}'", + "result": job_messages, + } + ) + return + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.error("Exception occurred while checking for job's messages!") + yield TriggerEvent({"status": "error", "message": str(e), "result": None}) + + async def list_job_messages(self) -> list[dict[str, str | dict]]: + """Wait for the Dataflow client response and then return it in a serialized list.""" + job_response: ListJobMessagesAsyncPager = await self.async_hook.list_job_messages( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + return self._get_job_messages_from_job_response(job_response) + + def _get_job_messages_from_job_response( + self, job_response: ListJobMessagesAsyncPager + ) -> list[dict[str, str | dict]]: + """Return a list of serialized JobMessage objects.""" + return [JobMessage.to_dict(message) for message in job_response.job_messages] + + @cached_property + def async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + ) diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index 905bb5790bca..d3f1bd6df4b6 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -288,6 +288,14 @@ When job is triggered asynchronously sensors may be used to run checks for speci :start-after: [START howto_sensor_wait_for_job_status] :end-before: [END howto_sensor_wait_for_job_status] +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_wait_for_job_status_deferrable] + :end-before: [END howto_sensor_wait_for_job_status_deferrable] + :class:`~airflow.providers.google.cloud.sensors.dataflow.DataflowJobMetricsSensor`. .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py @@ -296,6 +304,14 @@ When job is triggered asynchronously sensors may be used to run checks for speci :start-after: [START howto_sensor_wait_for_job_metric] :end-before: [END howto_sensor_wait_for_job_metric] +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_wait_for_job_metric_deferrable] + :end-before: [END howto_sensor_wait_for_job_metric_deferrable] + :class:`~airflow.providers.google.cloud.sensors.dataflow.DataflowJobMessagesSensor`. .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py @@ -304,6 +320,14 @@ When job is triggered asynchronously sensors may be used to run checks for speci :start-after: [START howto_sensor_wait_for_job_message] :end-before: [END howto_sensor_wait_for_job_message] +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_wait_for_job_message_deferrable] + :end-before: [END howto_sensor_wait_for_job_message_deferrable] + :class:`~airflow.providers.google.cloud.sensors.dataflow.DataflowJobAutoScalingEventsSensor`. .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_python_async.py @@ -312,6 +336,14 @@ When job is triggered asynchronously sensors may be used to run checks for speci :start-after: [START howto_sensor_wait_for_job_autoscaling_event] :end-before: [END howto_sensor_wait_for_job_autoscaling_event] +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_wait_for_job_autoscaling_event_deferrable] + :end-before: [END howto_sensor_wait_for_job_autoscaling_event_deferrable] + Reference ^^^^^^^^^ diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 0ea2ce3808ec..2458b48e8143 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -29,7 +29,14 @@ from uuid import UUID import pytest -from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView, ListJobsRequest +from google.cloud.dataflow_v1beta3 import ( + GetJobMetricsRequest, + GetJobRequest, + JobView, + ListJobMessagesRequest, + ListJobsRequest, +) +from google.cloud.dataflow_v1beta3.types import JobMessageImportance from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, run_beam_command @@ -1964,7 +1971,7 @@ def func(mock_obj, return_value): return func -class TestAsyncHook: +class TestAsyncDataflowHook: def test_delegate_to_runtime_error(self): with pytest.raises(RuntimeError): AsyncDataflowHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to") @@ -2023,3 +2030,47 @@ async def test_list_jobs(self, initialize_client_mock, hook, make_mock_awaitable ) initialize_client_mock.assert_called_once() client.list_jobs.assert_called_once_with(request=request) + + @pytest.mark.asyncio + @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client")) + async def test_list_job_messages(self, initialize_client_mock, hook): + client = initialize_client_mock.return_value + await hook.list_job_messages( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_JOB_ID, + ) + request = ListJobMessagesRequest( + { 
+ "project_id": TEST_PROJECT_ID, + "job_id": TEST_JOB_ID, + "minimum_importance": JobMessageImportance.JOB_MESSAGE_BASIC, + "page_size": None, + "page_token": None, + "start_time": None, + "end_time": None, + "location": TEST_LOCATION, + } + ) + initialize_client_mock.assert_called_once() + client.list_job_messages.assert_called_once_with(request=request) + + @pytest.mark.asyncio + @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client")) + async def test_get_job_metrics(self, initialize_client_mock, hook): + client = initialize_client_mock.return_value + await hook.get_job_metrics( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_JOB_ID, + ) + request = GetJobMetricsRequest( + { + "project_id": TEST_PROJECT_ID, + "job_id": TEST_JOB_ID, + "start_time": None, + "location": TEST_LOCATION, + } + ) + initialize_client_mock.assert_called_once() + client.get_job_metrics.assert_called_once_with(request=request) diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py b/tests/providers/google/cloud/sensors/test_dataflow.py index d669b2b11190..519ace52a4a1 100644 --- a/tests/providers/google/cloud/sensors/test_dataflow.py +++ b/tests/providers/google/cloud/sensors/test_dataflow.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.sensors.dataflow import ( DataflowJobAutoScalingEventsSensor, @@ -29,6 +29,12 @@ DataflowJobMetricsSensor, DataflowJobStatusSensor, ) +from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowJobAutoScalingEventTrigger, + DataflowJobMessagesTrigger, + DataflowJobMetricsTrigger, + DataflowJobStatusTrigger, +) TEST_TASK_ID = "task_id" TEST_JOB_ID = "test_job_id" @@ -104,6 +110,76 @@ def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION ) + @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowJobStatusSensor.poke") + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook") + def test_execute_enters_deferred_state(self, mock_hook, mock_poke): + """ + Tests that DataflowJobStatusTrigger will be fired when the DataflowJobStatusSensor + is executed and deferrable is set to True. 
+ """ + task = DataflowJobStatusSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + expected_statuses=DataflowJobStatus.JOB_STATE_DONE, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + mock_hook.return_value.exists.return_value = False + mock_poke.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(None) + assert isinstance( + exc.value.trigger, DataflowJobStatusTrigger + ), "Trigger is not a DataflowJobStatusTrigger" + + def test_execute_complete_success(self): + """Tests that the trigger event contains expected values if no callback function is provided.""" + expected_result = True + task = DataflowJobStatusSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + expected_statuses=DataflowJobStatus.JOB_STATE_DONE, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + actual_message = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Job with id '{TEST_JOB_ID}' has reached an expected state: {DataflowJobStatus.JOB_STATE_DONE}", + }, + ) + assert actual_message == expected_result + + @pytest.mark.parametrize( + "expected_exception, soft_fail", + ( + (AirflowException, False), + (AirflowSkipException, True), + ), + ) + def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" + task = DataflowJobStatusSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + location=TEST_LOCATION, + expected_statuses=DataflowJobStatus.JOB_STATE_DONE, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + soft_fail=soft_fail, + ) + with pytest.raises(expected_exception): + task.execute_complete(context=None, event={"status": "error", "message": "test error message"}) + class TestDataflowJobMetricsSensor: @pytest.mark.parametrize( @@ -145,8 +221,158 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): mock_fetch_job_metrics_by_id.return_value.__getitem__.assert_called_once_with("metrics") callback.assert_called_once_with(mock_fetch_job_metrics_by_id.return_value.__getitem__.return_value) + @pytest.mark.parametrize( + "soft_fail, expected_exception", + ((False, AirflowException), (True, AirflowSkipException)), + ) + @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") + def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): + mock_get_job = mock_hook.return_value.get_job + mock_fetch_job_messages_by_id = mock_hook.return_value.fetch_job_messages_by_id + callback = mock.MagicMock() + + task = DataflowJobMetricsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=callback, + fail_on_terminal_state=True, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + soft_fail=soft_fail, + ) + mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} + + with pytest.raises( + expected_exception, + match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " + f"{DataflowJobStatus.JOB_STATE_DONE}", + ): + task.poke(mock.MagicMock()) + + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, + 
impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + mock_fetch_job_messages_by_id.assert_not_called() + callback.assert_not_called() + + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook") + def test_execute_enters_deferred_state(self, mock_hook): + """ + Tests that DataflowJobMetricsTrigger will be fired when the DataflowJobMetricsSensor + is executed and deferrable is set to True. + """ + task = DataflowJobMetricsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(None) + assert isinstance( + exc.value.trigger, DataflowJobMetricsTrigger + ), "Trigger is not a DataflowJobMetricsTrigger" + + def test_execute_complete_success_without_callback_function(self): + """Tests that the trigger event contains expected values if no callback function is provided.""" + expected_result = [] + task = DataflowJobMetricsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + actual_message = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 metrics for job '{TEST_JOB_ID}'", + "result": [], + }, + ) + assert actual_message == expected_result + + def test_execute_complete_success_with_callback_function(self): + """Tests that the trigger event contains expected values if the callback function is provided.""" + expected_result = [ + { + "name": {"origin": "", "name": "", "context": {}}, + "scalar": 0.0, + "update_time": "2024-03-20T12:36:05.229Z", + "kind": "", + "cumulative": False, + }, + { + "name": {"origin": "", "name": "", "context": {}}, + "scalar": 0.0, + "update_time": "2024-03-20T12:36:05.229Z", + "kind": "", + "cumulative": False, + }, + ] + task = DataflowJobMetricsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=lambda res: res, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + actual_result = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 job messages for job '{TEST_JOB_ID}'", + "result": expected_result, + }, + ) + assert actual_result == expected_result -class DataflowJobMessagesSensorTest: + @pytest.mark.parametrize( + "expected_exception, soft_fail", + ( + (AirflowException, False), + (AirflowSkipException, True), + ), + ) + def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" + task = DataflowJobMetricsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=None, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + soft_fail=soft_fail, + ) + with pytest.raises(expected_exception): + task.execute_complete( + context=None, event={"status": "error", "message": "test error message", "result": None} + 
) + + +class TestDataflowJobMessagesSensor: @pytest.mark.parametrize( "job_current_state, fail_on_terminal_state", [ @@ -187,7 +413,8 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): callback.assert_called_once_with(mock_fetch_job_messages_by_id.return_value) @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + "soft_fail, expected_exception", + ((False, AirflowException), (True, AirflowSkipException)), ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): @@ -222,8 +449,119 @@ def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): mock_fetch_job_messages_by_id.assert_not_called() callback.assert_not_called() + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook") + def test_execute_enters_deferred_state(self, mock_hook): + """ + Tests that DataflowJobMessagesTrigger will be fired when the DataflowJobMessagesSensor + is executed and deferrable is set to True. + """ + task = DataflowJobMessagesSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(None) + assert isinstance( + exc.value.trigger, DataflowJobMessagesTrigger + ), "Trigger is not a DataflowJobMessagesTrigger" + + def test_execute_complete_success_without_callback_function(self): + """Tests that the trigger event contains expected values if no callback function is provided.""" + expected_result = [] + task = DataflowJobMessagesSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + actual_message = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 job messages for job '{TEST_JOB_ID}'", + "result": [], + }, + ) + assert actual_message == expected_result + + def test_execute_complete_success_with_callback_function(self): + """Tests that the trigger event contains expected values if the callback function is provided.""" + expected_result = [ + { + "id": "1707695235850", + "time": "2024-02-06T23:47:15.850Z", + "message_text": "msg.", + "message_importance": 5, + }, + { + "id": "1707695635401", + "time": "2024-02-06T23:53:55.401Z", + "message_text": "msg.", + "message_importance": 5, + }, + ] + task = DataflowJobMessagesSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=lambda res: res, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + actual_result = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 job messages for job '{TEST_JOB_ID}'", + "result": expected_result, + }, + ) + assert actual_result == expected_result + + @pytest.mark.parametrize( + "expected_exception, soft_fail", + ( + (AirflowException, False), + (AirflowSkipException, True), + ), + ) + def test_execute_complete_not_success_status_raises_exception(self, 
expected_exception, soft_fail): + """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" + task = DataflowJobMessagesSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=None, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + soft_fail=soft_fail, + ) + with pytest.raises(expected_exception): + task.execute_complete( + context=None, event={"status": "error", "message": "test error message", "result": None} + ) + -class DataflowJobAutoScalingEventsSensorTest: +class TestDataflowJobAutoScalingEventsSensor: @pytest.mark.parametrize( "job_current_state, fail_on_terminal_state", [ @@ -264,7 +602,11 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): callback.assert_called_once_with(mock_fetch_job_autoscaling_events_by_id.return_value) @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + "soft_fail, expected_exception", + ( + (False, AirflowException), + (True, AirflowSkipException), + ), ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") def test_poke_raise_exception_on_terminal_state(self, mock_hook, soft_fail, expected_exception): @@ -298,3 +640,113 @@ def test_poke_raise_exception_on_terminal_state(self, mock_hook, soft_fail, expe ) mock_fetch_job_autoscaling_events_by_id.assert_not_called() callback.assert_not_called() + + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook") + def test_execute_enters_deferred_state(self, mock_hook): + """ + Tests that AutoScalingEventTrigger will be fired when the DataflowJobAutoScalingEventSensor + is executed and deferrable is set to True. 
+ """ + task = DataflowJobAutoScalingEventsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(None) + assert isinstance( + exc.value.trigger, DataflowJobAutoScalingEventTrigger + ), "Trigger is not a DataflowJobAutoScalingEventTrigger" + + def test_execute_complete_success_without_callback_function(self): + """Tests that the trigger event contains expected values if no callback function is provided.""" + expected_result = [] + task = DataflowJobAutoScalingEventsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + callback=None, + ) + actual_message = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 autoscaling events for job '{TEST_JOB_ID}'", + "result": [], + }, + ) + assert actual_message == expected_result + + def test_execute_complete_success_with_callback_function(self): + """Tests that the trigger event contains expected values if the callback function is provided.""" + expected_result = [ + { + "event_type": 2, + "description": {}, + "time": "2024-02-05T13:43:31.066611771Z", + }, + { + "event_type": 1, + "description": {}, + "time": "2024-02-05T13:43:31.066611771Z", + }, + ] + task = DataflowJobAutoScalingEventsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=lambda res: res, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + ) + actual_result = task.execute_complete( + context=None, + event={ + "status": "success", + "message": f"Detected 2 autoscaling events for job '{TEST_JOB_ID}'", + "result": expected_result, + }, + ) + assert actual_result == expected_result + + @pytest.mark.parametrize( + "expected_exception, soft_fail", + ( + (AirflowException, False), + (AirflowSkipException, True), + ), + ) + def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" + task = DataflowJobAutoScalingEventsSensor( + task_id=TEST_TASK_ID, + job_id=TEST_JOB_ID, + callback=None, + fail_on_terminal_state=False, + location=TEST_LOCATION, + project_id=TEST_PROJECT_ID, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + deferrable=True, + soft_fail=soft_fail, + ) + with pytest.raises(expected_exception): + task.execute_complete( + context=None, + event={"status": "error", "message": "test error message", "result": None}, + ) diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py b/tests/providers/google/cloud/triggers/test_dataflow.py index eb4889aabfeb..2b9b63afa44e 100644 --- a/tests/providers/google/cloud/triggers/test_dataflow.py +++ b/tests/providers/google/cloud/triggers/test_dataflow.py @@ -24,7 +24,14 @@ import pytest from google.cloud.dataflow_v1beta3 import JobState -from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger +from 
airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowJobAutoScalingEventTrigger, + DataflowJobMessagesTrigger, + DataflowJobMetricsTrigger, + DataflowJobStatusTrigger, + TemplateJobStartTrigger, +) from airflow.triggers.base import TriggerEvent PROJECT_ID = "test-project-id" @@ -37,7 +44,7 @@ @pytest.fixture -def trigger(): +def template_job_start_trigger(): return TemplateJobStartTrigger( project_id=PROJECT_ID, job_id=JOB_ID, @@ -49,9 +56,61 @@ def trigger(): ) +@pytest.fixture +def dataflow_job_autoscaling_event_trigger(): + return DataflowJobAutoScalingEventTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + fail_on_terminal_state=False, + ) + + +@pytest.fixture +def dataflow_job_messages_trigger(): + return DataflowJobMessagesTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + fail_on_terminal_state=False, + ) + + +@pytest.fixture +def dataflow_job_metrics_trigger(): + return DataflowJobMetricsTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + fail_on_terminal_state=False, + ) + + +@pytest.fixture +def dataflow_job_status_trigger(): + return DataflowJobStatusTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + expected_statuses={JobState.JOB_STATE_DONE, JobState.JOB_STATE_FAILED}, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + class TestTemplateJobStartTrigger: - def test_serialize(self, trigger): - actual_data = trigger.serialize() + def test_serialize(self, template_job_start_trigger): + actual_data = template_job_start_trigger.serialize() expected_data = ( "airflow.providers.google.cloud.triggers.dataflow.TemplateJobStartTrigger", { @@ -75,15 +134,15 @@ def test_serialize(self, trigger): ("cancel_timeout", CANCEL_TIMEOUT), ], ) - def test_get_async_hook(self, trigger, attr, expected): - hook = trigger._get_async_hook() + def test_get_async_hook(self, template_job_start_trigger, attr, expected): + hook = template_job_start_trigger._get_async_hook() actual = hook._hook_kwargs.get(attr) assert actual is not None assert actual == expected @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") - async def test_run_loop_return_success_event(self, mock_job_status, trigger): + async def test_run_loop_return_success_event(self, mock_job_status, template_job_start_trigger): mock_job_status.return_value = JobState.JOB_STATE_DONE expected_event = TriggerEvent( @@ -93,13 +152,13 @@ async def test_run_loop_return_success_event(self, mock_job_status, trigger): "message": "Job completed", } ) - actual_event = await trigger.run().asend(None) + actual_event = await template_job_start_trigger.run().asend(None) assert actual_event == expected_event @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") - async def test_run_loop_return_failed_event(self, mock_job_status, trigger): + async def test_run_loop_return_failed_event(self, mock_job_status, template_job_start_trigger): mock_job_status.return_value = JobState.JOB_STATE_FAILED expected_event = TriggerEvent( @@ -108,13 
+167,13 @@ async def test_run_loop_return_failed_event(self, mock_job_status, trigger): "message": f"Dataflow job with id {JOB_ID} has failed its execution", } ) - actual_event = await trigger.run().asend(None) + actual_event = await template_job_start_trigger.run().asend(None) assert actual_event == expected_event @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") - async def test_run_loop_return_stopped_event(self, mock_job_status, trigger): + async def test_run_loop_return_stopped_event(self, mock_job_status, template_job_start_trigger): mock_job_status.return_value = JobState.JOB_STATE_STOPPED expected_event = TriggerEvent( { @@ -122,19 +181,551 @@ async def test_run_loop_return_stopped_event(self, mock_job_status, trigger): "message": f"Dataflow job with id {JOB_ID} was stopped", } ) - actual_event = await trigger.run().asend(None) + actual_event = await template_job_start_trigger.run().asend(None) assert actual_event == expected_event @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") - async def test_run_loop_is_still_running(self, mock_job_status, trigger, caplog): + async def test_run_loop_is_still_running(self, mock_job_status, template_job_start_trigger, caplog): mock_job_status.return_value = JobState.JOB_STATE_RUNNING caplog.set_level(logging.INFO) - task = asyncio.create_task(trigger.run().__anext__()) + task = asyncio.create_task(template_job_start_trigger.run().__anext__()) await asyncio.sleep(0.5) assert not task.done() assert f"Current job status is: {JobState.JOB_STATE_RUNNING}" assert f"Sleeping for {POLL_SLEEP} seconds." + # cancel the task to suppress test warnings + task.cancel() + + +class TestDataflowJobAutoScalingEventTrigger: + def test_serialize(self, dataflow_job_autoscaling_event_trigger): + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "impersonation_chain": IMPERSONATION_CHAIN, + "fail_on_terminal_state": False, + }, + ) + actual_data = dataflow_job_autoscaling_event_trigger.serialize() + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ], + ) + def test_async_hook(self, dataflow_job_autoscaling_event_trigger, attr, expected): + hook = dataflow_job_autoscaling_event_trigger.async_hook + actual = hook._hook_kwargs.get(attr) + assert actual == expected + + @pytest.mark.parametrize( + "job_status_value", + [ + JobState.JOB_STATE_DONE, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_UPDATED, + JobState.JOB_STATE_DRAINED, + ], + ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger.list_job_autoscaling_events" + ) + async def test_run_yields_terminal_state_event_if_fail_on_terminal_state( + self, + mock_list_job_autoscaling_events, + mock_job_status, + job_status_value, + dataflow_job_autoscaling_event_trigger, + ): + dataflow_job_autoscaling_event_trigger.fail_on_terminal_state = True + mock_list_job_autoscaling_events.return_value = [] + mock_job_status.return_value = job_status_value + 
expected_event = TriggerEvent( + { + "status": "error", + "message": f"Job with id '{JOB_ID}' is already in terminal state: {job_status_value.name}", + "result": None, + } + ) + actual_event = await dataflow_job_autoscaling_event_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger.list_job_autoscaling_events" + ) + async def test_run_loop_is_still_running_if_fail_on_terminal_state( + self, + mock_list_job_autoscaling_events, + mock_job_status, + dataflow_job_autoscaling_event_trigger, + caplog, + ): + """Test that DataflowJobAutoScalingEventTrigger is still in loop if the job status is RUNNING.""" + dataflow_job_autoscaling_event_trigger.fail_on_terminal_state = True + mock_job_status.return_value = JobState.JOB_STATE_RUNNING + mock_list_job_autoscaling_events.return_value = [] + caplog.set_level(logging.INFO) + task = asyncio.create_task(dataflow_job_autoscaling_event_trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert task.done() is False + # cancel the task to suppress test warnings + task.cancel() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger.list_job_autoscaling_events" + ) + async def test_run_yields_autoscaling_events( + self, mock_list_job_autoscaling_events, mock_job_status, dataflow_job_autoscaling_event_trigger + ): + mock_job_status.return_value = JobState.JOB_STATE_DONE + test_autoscaling_events = [ + { + "event_type": 2, + "description": {}, + "time": "2024-02-05T13:43:31.066611771Z", + "worker_pool": "Regular", + "current_num_workers": "0", + "target_num_workers": "0", + }, + { + "target_num_workers": "1", + "event_type": 1, + "description": {}, + "time": "2024-02-05T13:43:31.066611771Z", + "worker_pool": "Regular", + "current_num_workers": "0", + }, + ] + mock_list_job_autoscaling_events.return_value = test_autoscaling_events + expected_event = TriggerEvent( + { + "status": "success", + "message": f"Detected 2 autoscaling events for job '{JOB_ID}'", + "result": test_autoscaling_events, + } + ) + actual_event = await dataflow_job_autoscaling_event_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_raises_exception(self, mock_job_status, dataflow_job_autoscaling_event_trigger): + """ + Tests the DataflowJobAutoScalingEventTrigger does fire if there is an exception. 
+ """ + mock_job_status.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent({"status": "error", "message": "Test exception", "result": None}) + actual_event = await dataflow_job_autoscaling_event_trigger.run().asend(None) + assert expected_event == actual_event + + +class TestDataflowJobMessagesTrigger: + """Test case for DataflowJobMessagesTrigger""" + + def test_serialize(self, dataflow_job_messages_trigger): + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "impersonation_chain": IMPERSONATION_CHAIN, + "fail_on_terminal_state": False, + }, + ) + actual_data = dataflow_job_messages_trigger.serialize() + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ], + ) + def test_async_hook(self, dataflow_job_messages_trigger, attr, expected): + hook = dataflow_job_messages_trigger.async_hook + actual = hook._hook_kwargs.get(attr) + assert actual == expected + + @pytest.mark.parametrize( + "job_status_value", + [ + JobState.JOB_STATE_DONE, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_UPDATED, + JobState.JOB_STATE_DRAINED, + ], + ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger.list_job_messages" + ) + async def test_run_yields_terminal_state_event_if_fail_on_terminal_state( + self, + mock_list_job_messages, + mock_job_status, + job_status_value, + dataflow_job_messages_trigger, + ): + dataflow_job_messages_trigger.fail_on_terminal_state = True + mock_list_job_messages.return_value = [] + mock_job_status.return_value = job_status_value + expected_event = TriggerEvent( + { + "status": "error", + "message": f"Job with id '{JOB_ID}' is already in terminal state: {job_status_value.name}", + "result": None, + } + ) + actual_event = await dataflow_job_messages_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger.list_job_messages" + ) + async def test_run_loop_is_still_running_if_fail_on_terminal_state( + self, + mock_list_job_messages, + mock_job_status, + dataflow_job_messages_trigger, + caplog, + ): + """Test that DataflowJobMessagesTrigger is still in loop if the job status is RUNNING.""" + dataflow_job_messages_trigger.fail_on_terminal_state = True + mock_job_status.return_value = JobState.JOB_STATE_RUNNING + mock_list_job_messages.return_value = [] + caplog.set_level(logging.INFO) + task = asyncio.create_task(dataflow_job_messages_trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert task.done() is False + # cancel the task to suppress test warnings + task.cancel() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger.list_job_messages" + ) + async def test_run_yields_job_messages( + self, mock_list_job_messages, 
mock_job_status, dataflow_job_messages_trigger + ): + mock_job_status.return_value = JobState.JOB_STATE_DONE + test_job_messages = [ + { + "id": "1707695235850", + "time": "2024-02-06T23:47:15.850Z", + "message_text": "msg.", + "message_importance": 5, + }, + { + "id": "1707695635401", + "time": "2024-02-06T23:53:55.401Z", + "message_text": "msg.", + "message_importance": 5, + }, + ] + mock_list_job_messages.return_value = test_job_messages + expected_event = TriggerEvent( + { + "status": "success", + "message": f"Detected 2 job messages for job '{JOB_ID}'", + "result": test_job_messages, + } + ) + actual_event = await dataflow_job_messages_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_raises_exception(self, mock_job_status, dataflow_job_messages_trigger): + """ + Tests the DataflowJobMessagesTrigger does fire if there is an exception. + """ + mock_job_status.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent({"status": "error", "message": "Test exception", "result": None}) + actual_event = await dataflow_job_messages_trigger.run().asend(None) + assert expected_event == actual_event + + +class TestDataflowJobMetricsTrigger: + """Test case for DataflowJobMetricsTrigger""" + + def test_serialize(self, dataflow_job_metrics_trigger): + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "impersonation_chain": IMPERSONATION_CHAIN, + "fail_on_terminal_state": False, + }, + ) + actual_data = dataflow_job_metrics_trigger.serialize() + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ], + ) + def test_async_hook(self, dataflow_job_metrics_trigger, attr, expected): + hook = dataflow_job_metrics_trigger.async_hook + actual = hook._hook_kwargs.get(attr) + assert actual == expected + + @pytest.mark.parametrize( + "job_status_value", + [ + JobState.JOB_STATE_DONE, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_UPDATED, + JobState.JOB_STATE_DRAINED, + ], + ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch("airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger.get_job_metrics") + async def test_run_yields_terminal_state_event_if_fail_on_terminal_state( + self, + mock_get_job_metrics, + mock_job_status, + job_status_value, + dataflow_job_metrics_trigger, + ): + dataflow_job_metrics_trigger.fail_on_terminal_state = True + mock_get_job_metrics.return_value = [] + mock_job_status.return_value = job_status_value + expected_event = TriggerEvent( + { + "status": "error", + "message": f"Job with id '{JOB_ID}' is already in terminal state: {job_status_value.name}", + "result": None, + } + ) + actual_event = await dataflow_job_metrics_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + @mock.patch("airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger.get_job_metrics") + async def 
test_run_loop_is_still_running_if_fail_on_terminal_state(
+        self,
+        mock_get_job_metrics,
+        mock_job_status,
+        dataflow_job_metrics_trigger,
+        caplog,
+    ):
+        """Test that DataflowJobMetricsTrigger is still in the polling loop if the job status is RUNNING."""
+        dataflow_job_metrics_trigger.fail_on_terminal_state = True
+        mock_job_status.return_value = JobState.JOB_STATE_RUNNING
+        mock_get_job_metrics.return_value = []
+        caplog.set_level(logging.INFO)
+        task = asyncio.create_task(dataflow_job_metrics_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is False
+        # cancel the task to suppress test warnings
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger.get_job_metrics")
+    async def test_run_yields_job_metrics(
+        self, mock_get_job_metrics, mock_job_status, dataflow_job_metrics_trigger
+    ):
+        mock_job_status.return_value = JobState.JOB_STATE_DONE
+        test_job_metrics = [
+            {
+                "name": {"origin": "", "name": "", "context": {}},
+                "scalar": 0.0,
+                "update_time": "2024-03-20T12:36:05.229Z",
+                "kind": "",
+                "cumulative": False,
+            },
+            {
+                "name": {"origin": "", "name": "", "context": {}},
+                "scalar": 0.0,
+                "update_time": "2024-03-20T12:36:05.229Z",
+                "kind": "",
+                "cumulative": False,
+            },
+        ]
+        mock_get_job_metrics.return_value = test_job_metrics
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Detected 2 metrics for job '{JOB_ID}'",
+                "result": test_job_metrics,
+            }
+        )
+        actual_event = await dataflow_job_metrics_trigger.run().asend(None)
+        assert actual_event == expected_event
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+    async def test_run_raises_exception(self, mock_job_status, dataflow_job_metrics_trigger):
+        """
+        Tests that the DataflowJobMetricsTrigger fires an error event if an exception is raised.
+ """ + mock_job_status.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent({"status": "error", "message": "Test exception", "result": None}) + actual_event = await dataflow_job_metrics_trigger.run().asend(None) + assert expected_event == actual_event + + +class TestDataflowJobStatusTrigger: + """Test case for DataflowJobStatusTrigger""" + + def test_serialize(self, dataflow_job_status_trigger): + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStatusTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "expected_statuses": {JobState.JOB_STATE_DONE, JobState.JOB_STATE_FAILED}, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "impersonation_chain": IMPERSONATION_CHAIN, + }, + ) + actual_data = dataflow_job_status_trigger.serialize() + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ], + ) + def test_async_hook(self, dataflow_job_status_trigger, attr, expected): + hook = dataflow_job_status_trigger.async_hook + actual = hook._hook_kwargs.get(attr) + assert actual == expected + + @pytest.mark.parametrize( + "job_status_value", + [ + JobState.JOB_STATE_DONE, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_UPDATED, + JobState.JOB_STATE_DRAINED, + ], + ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_yields_terminal_state_event( + self, + mock_job_status, + job_status_value, + dataflow_job_status_trigger, + ): + dataflow_job_status_trigger.expected_statuses = {DataflowJobStatus.JOB_STATE_CANCELLING} + mock_job_status.return_value = job_status_value + expected_event = TriggerEvent( + { + "status": "error", + "message": f"Job with id '{JOB_ID}' is already in terminal state: {job_status_value.name}", + } + ) + actual_event = await dataflow_job_status_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.parametrize( + "job_status_value", + [ + JobState.JOB_STATE_DONE, + JobState.JOB_STATE_RUNNING, + ], + ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_yields_success_event_if_expected_job_status( + self, + mock_job_status, + job_status_value, + dataflow_job_status_trigger, + ): + dataflow_job_status_trigger.expected_statuses = { + DataflowJobStatus.JOB_STATE_DONE, + DataflowJobStatus.JOB_STATE_RUNNING, + } + mock_job_status.return_value = job_status_value + expected_event = TriggerEvent( + { + "status": "success", + "message": f"Job with id '{JOB_ID}' has reached an expected state: {job_status_value.name}", + } + ) + actual_event = await dataflow_job_status_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_loop_is_still_running_if_state_is_not_terminal_or_expected( + self, + mock_job_status, + dataflow_job_status_trigger, + caplog, + ): + """Test that DataflowJobStatusTrigger is still in loop if the job status neither terminal nor expected.""" + dataflow_job_status_trigger.expected_statuses = {DataflowJobStatus.JOB_STATE_DONE} + mock_job_status.return_value = JobState.JOB_STATE_RUNNING + caplog.set_level(logging.INFO) + 
task = asyncio.create_task(dataflow_job_status_trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert task.done() is False + task.cancel() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status") + async def test_run_raises_exception(self, mock_job_status, dataflow_job_status_trigger): + """ + Tests the DataflowJobStatusTrigger does fire if there is an exception. + """ + mock_job_status.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent( + { + "status": "error", + "message": "Test exception", + } + ) + actual_event = await dataflow_job_status_trigger.run().asend(None) + assert expected_event == actual_event diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py new file mode 100644 index 000000000000..fbe826b784c9 --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_sensors_deferrable.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Example Airflow DAG for testing Google Dataflow Beam Pipeline Operator with Asynchronous Python in the deferrable mode.""" + +from __future__ import annotations + +import os +from datetime import datetime +from typing import Callable + +from airflow.exceptions import AirflowException +from airflow.models.dag import DAG +from airflow.providers.apache.beam.hooks.beam import BeamRunnerType +from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.sensors.dataflow import ( + DataflowJobAutoScalingEventsSensor, + DataflowJobMessagesSensor, + DataflowJobMetricsSensor, + DataflowJobStatusSensor, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "dataflow_sensors_deferrable" + +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" +BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" + +GCS_TMP = f"gs://{BUCKET_NAME}/temp/" +GCS_STAGING = f"gs://{BUCKET_NAME}/staging/" +GCS_OUTPUT = f"gs://{BUCKET_NAME}/output" +GCS_PYTHON_SCRIPT = f"gs://{RESOURCE_DATA_BUCKET}/dataflow/python/wordcount_debugging.py" +LOCATION = "europe-west3" + +default_args = { + "dataflow_default_options": { + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, + } +} + + +with DAG( + DAG_ID, + default_args=default_args, + schedule="@once", + start_date=datetime(2024, 1, 1), + catchup=False, + tags=["example", "dataflow"], +) as dag: + create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME) + + start_beam_python_pipeline = BeamRunPythonPipelineOperator( + task_id="start_beam_python_pipeline", + runner=BeamRunnerType.DataflowRunner, + py_file=GCS_PYTHON_SCRIPT, + py_options=[], + pipeline_options={ + "output": GCS_OUTPUT, + }, + py_requirements=["apache-beam[gcp]==2.47.0"], + py_interpreter="python3", + py_system_site_packages=False, + dataflow_config={ + "job_name": "start_beam_python_pipeline", + "location": LOCATION, + "wait_until_finished": False, + }, + ) + + # [START howto_sensor_wait_for_job_status_deferrable] + wait_for_beam_python_pipeline_job_status_def = DataflowJobStatusSensor( + task_id="wait_for_beam_python_pipeline_job_status_def", + job_id="{{task_instance.xcom_pull('start_beam_python_pipeline')['dataflow_job_id']}}", + expected_statuses=DataflowJobStatus.JOB_STATE_DONE, + location=LOCATION, + deferrable=True, + ) + # [END howto_sensor_wait_for_job_status_deferrable] + + # [START howto_sensor_wait_for_job_metric_deferrable] + def check_metric_scalar_gte(metric_name: str, value: int) -> Callable: + """Check is metric greater than equals to given value.""" + + def callback(metrics: list[dict]) -> bool: + dag.log.info("Looking for '%s' >= %d", metric_name, value) + for metric in metrics: + context = metric.get("name", {}).get("context", {}) + original_name = context.get("original_name", "") + tentative = context.get("tentative", "") + if original_name == "Service-cpu_num_seconds" and not tentative: + return metric["scalar"] >= value + raise AirflowException(f"Metric '{metric_name}' not found in metrics") + + return callback + + wait_for_beam_python_pipeline_job_metric_def = DataflowJobMetricsSensor( + task_id="wait_for_beam_python_pipeline_job_metric_def", + job_id="{{task_instance.xcom_pull('start_beam_python_pipeline')['dataflow_job_id']}}", + location=LOCATION, + 
callback=check_metric_scalar_gte(metric_name="Service-cpu_num_seconds", value=100), + fail_on_terminal_state=False, + deferrable=True, + ) + # [END howto_sensor_wait_for_job_metric_deferrable] + + # [START howto_sensor_wait_for_job_message_deferrable] + def check_job_message(messages: list[dict]) -> bool: + """Check job message.""" + for message in messages: + if "Adding workflow start and stop steps." in message.get("messageText", ""): + return True + return False + + wait_for_beam_python_pipeline_job_message_def = DataflowJobMessagesSensor( + task_id="wait_for_beam_python_pipeline_job_message_def", + job_id="{{task_instance.xcom_pull('start_beam_python_pipeline')['dataflow_job_id']}}", + location=LOCATION, + callback=check_job_message, + fail_on_terminal_state=False, + deferrable=True, + ) + # [END howto_sensor_wait_for_job_message_deferrable] + + # [START howto_sensor_wait_for_job_autoscaling_event_deferrable] + def check_autoscaling_event(autoscaling_events: list[dict]) -> bool: + """Check autoscaling event.""" + for autoscaling_event in autoscaling_events: + if "Worker pool started." in autoscaling_event.get("description", {}).get("messageText", ""): + return True + return False + + wait_for_beam_python_pipeline_job_autoscaling_event_def = DataflowJobAutoScalingEventsSensor( + task_id="wait_for_beam_python_pipeline_job_autoscaling_event_def", + job_id="{{task_instance.xcom_pull('start_beam_python_pipeline')['dataflow_job_id']}}", + location=LOCATION, + callback=check_autoscaling_event, + fail_on_terminal_state=False, + deferrable=True, + ) + # [END howto_sensor_wait_for_job_autoscaling_event_deferrable] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + # TEST SETUP + create_bucket + >> start_beam_python_pipeline + # TEST BODY + >> [ + wait_for_beam_python_pipeline_job_status_def, + wait_for_beam_python_pipeline_job_metric_def, + wait_for_beam_python_pipeline_job_message_def, + wait_for_beam_python_pipeline_job_autoscaling_event_def, + ] + # TEST TEARDOWN + >> delete_bucket + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
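
Note on the sensor tests in this patch: the execute and execute_complete assertions for DataflowJobMetricsSensor, DataflowJobMessagesSensor and DataflowJobAutoScalingEventsSensor all pin the same deferrable contract: defer immediately with the matching trigger, then either return the trigger payload (optionally passed through the user callback) or raise AirflowException / AirflowSkipException. The sketch below only illustrates that contract; the class name, the trigger_cls parameter and the constructor are hypothetical and this is not the provider's implementation. Only the deferrable path is shown.

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


class DeferrableDataflowSensorSketch(BaseSensorOperator):
    """Hypothetical stand-in mirroring the contract asserted by the sensor tests above."""

    def __init__(self, *, job_id, project_id, location, trigger_cls, callback=None,
                 fail_on_terminal_state=False, **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.project_id = project_id
        self.location = location
        self.trigger_cls = trigger_cls  # e.g. one of the new Dataflow triggers
        self.callback = callback
        self.fail_on_terminal_state = fail_on_terminal_state

    def execute(self, context):
        # Deferring raises TaskDeferred, which the test_execute_enters_deferred_state
        # cases catch in order to inspect the trigger instance.
        self.defer(
            trigger=self.trigger_cls(
                job_id=self.job_id,
                project_id=self.project_id,
                location=self.location,
                fail_on_terminal_state=self.fail_on_terminal_state,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event):
        # On success, return the trigger payload, optionally passed through the user
        # callback; on error, honour soft_fail, as the execute_complete tests assert.
        if event["status"] == "success":
            return self.callback(event["result"]) if self.callback else event["result"]
        if self.soft_fail:
            raise AirflowSkipException(event["message"])
        raise AirflowException(event["message"])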
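
The run() tests for DataflowJobAutoScalingEventTrigger, DataflowJobMessagesTrigger and DataflowJobMetricsTrigger all pin the same polling behaviour: keep looping while the job runs, report an error event if fail_on_terminal_state is set and a terminal state is reached, otherwise emit a success event carrying the fetched results, and turn any exception into an error event. The async-generator below is a minimal sketch that satisfies exactly those assertions; it is not the triggers' actual code, the branch ordering for cases the tests do not cover is an assumption, and fetch_results stands in for list_job_autoscaling_events / list_job_messages / get_job_metrics.

import asyncio

from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.triggers.base import TriggerEvent


async def poll_dataflow_job(async_hook, fetch_results, *, job_id, project_id, location,
                            fail_on_terminal_state, poll_sleep):
    """Illustrative polling loop consistent with the trigger run() tests above (not provider code)."""
    try:
        while True:
            status = await async_hook.get_job_status(
                job_id=job_id, project_id=project_id, location=location
            )
            results = await fetch_results()
            if status.name in DataflowJobStatus.TERMINAL_STATES:
                if fail_on_terminal_state:
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": f"Job with id '{job_id}' is already in terminal state: {status.name}",
                            "result": None,
                        }
                    )
                    return
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": f"Detected {len(results)} results for job '{job_id}'",
                        "result": results,
                    }
                )
                return
            # Job is not yet terminal: keep polling, which is what the
            # test_run_loop_is_still_running_if_fail_on_terminal_state cases check.
            await asyncio.sleep(poll_sleep)
    except Exception as e:
        yield TriggerEvent({"status": "error", "message": str(e), "result": None})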
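
DataflowJobStatusTrigger's tests pin a slightly different loop: its events carry no "result" key, success fires as soon as the job reaches one of expected_statuses, a terminal state outside that set is an error, and anything else keeps the loop alive. A hedged sketch of a loop that matches those assertions (again not the provider's code):

import asyncio

from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.triggers.base import TriggerEvent


async def poll_for_expected_status(async_hook, *, job_id, project_id, location,
                                   expected_statuses, poll_sleep):
    """Illustrative loop matching the TestDataflowJobStatusTrigger assertions."""
    try:
        while True:
            status = await async_hook.get_job_status(
                job_id=job_id, project_id=project_id, location=location
            )
            if status.name in expected_statuses:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": f"Job with id '{job_id}' has reached an expected state: {status.name}",
                    }
                )
                return
            if status.name in DataflowJobStatus.TERMINAL_STATES:
                yield TriggerEvent(
                    {
                        "status": "error",
                        "message": f"Job with id '{job_id}' is already in terminal state: {status.name}",
                    }
                )
                return
            # Neither expected nor terminal (e.g. JOB_STATE_RUNNING): keep polling.
            await asyncio.sleep(poll_sleep)
    except Exception as e:
        yield TriggerEvent({"status": "error", "message": str(e)})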
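
Because deferral round-trips through the triggerer process, the test_serialize cases are what guarantee the new triggers can be rebuilt from their serialized form. A usage sketch, assuming a provider version that includes this patch and that the constructor accepts exactly the serialized keys; the values are illustrative, not the test constants:

from airflow.providers.google.cloud.triggers.dataflow import DataflowJobMessagesTrigger

trigger = DataflowJobMessagesTrigger(
    project_id="example-project",  # illustrative values
    job_id="2024-02-06_00_00_00-1234567890123456789",
    location="europe-west3",
    gcp_conn_id="google_cloud_default",
    poll_sleep=10,
    impersonation_chain=None,
    fail_on_terminal_state=False,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger"
# The triggerer re-instantiates the trigger from exactly this payload:
rebuilt = DataflowJobMessagesTrigger(**kwargs)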