From 1613e9ec1c4e5523953e045c8adcef1b9d4ce95d Mon Sep 17 00:00:00 2001
From: raphaelauv
Date: Sun, 25 Aug 2024 14:52:55 +0200
Subject: [PATCH] remove soft_fail (#41710)

---
 airflow/providers/airbyte/sensors/airbyte.py  | 8 +-
 .../alibaba/cloud/sensors/oss_key.py | 8 +-
 airflow/providers/amazon/aws/sensors/batch.py | 4 +-
 airflow/providers/amazon/aws/sensors/ecs.py | 8 +-
 airflow/providers/amazon/aws/sensors/glue.py | 6 +-
 .../apache/flink/sensors/flink_kubernetes.py | 5 +-
 .../providers/celery/sensors/celery_queue.py | 9 --
 .../kubernetes/sensors/spark_kubernetes.py | 5 +-
 airflow/providers/common/sql/sensors/sql.py | 14 +--
 airflow/providers/common/sql/sensors/sql.pyi | 5 +-
 .../sensors/databricks_partition.py | 14 +--
 .../databricks/sensors/databricks_sql.py | 5 +-
 airflow/providers/datadog/sensors/datadog.py | 5 +-
 airflow/providers/dbt/cloud/sensors/dbt.py | 11 +--
 .../google/cloud/sensors/bigquery.py | 14 +--
 .../google/cloud/sensors/bigquery_dts.py | 5 +-
 .../google/cloud/sensors/cloud_composer.py | 8 +-
 .../google/cloud/sensors/dataflow.py | 26 +----
 .../google/cloud/sensors/dataform.py | 5 +-
 .../google/cloud/sensors/datafusion.py | 8 +-
 .../google/cloud/sensors/dataplex.py | 32 +------
 .../google/cloud/sensors/dataproc.py | 17 +---
 .../cloud/sensors/dataproc_metastore.py | 8 +-
 airflow/providers/google/cloud/sensors/gcs.py | 23 +----
 .../providers/google/cloud/sensors/looker.py | 14 +--
 .../providers/google/cloud/sensors/pubsub.py | 5 +-
 .../google/cloud/sensors/workflows.py | 5 +-
 .../sensors/display_video.py | 5 +-
 airflow/providers/jenkins/sensors/jenkins.py | 5 +-
 .../microsoft/azure/sensors/data_factory.py | 11 +--
 .../providers/microsoft/azure/sensors/wasb.py | 8 +-
 airflow/providers/sftp/sensors/sftp.py | 7 +-
 airflow/providers/tableau/sensors/tableau.py | 4 -
 .../providers/airbyte/sensors/test_airbyte.py | 18 ++--
 .../alibaba/cloud/sensors/test_oss_key.py | 12 ++-
 .../amazon/aws/sensors/test_batch.py | 10 +--
 .../flink/sensors/test_flink_kubernetes.py | 13 +--
 .../celery/sensors/test_celery_queue.py | 10 +--
 .../sensors/test_spark_kubernetes.py | 32 ++-----
 .../providers/common/sql/sensors/test_sql.py | 59 ++++---------
 .../sensors/test_databricks_partition.py | 37 +++-----
 .../databricks/sensors/test_databricks_sql.py | 10 +--
 .../providers/datadog/sensors/test_datadog.py | 10 +--
 tests/providers/dbt/cloud/sensors/test_dbt.py | 11 +--
 .../google/cloud/sensors/test_bigquery.py | 76 ++++------------
 .../google/cloud/sensors/test_bigtable.py | 12 +--
 .../cloud/sensors/test_cloud_composer.py | 11 +--
 .../google/cloud/sensors/test_dataflow.py | 88 ++++---------------
 .../google/cloud/sensors/test_datafusion.py | 18 ++--
 .../google/cloud/sensors/test_dataplex.py | 10 +--
 .../google/cloud/sensors/test_dataproc.py | 48 ++++------
 .../cloud/sensors/test_dataproc_metastore.py | 20 ++---
 .../google/cloud/sensors/test_gcs.py | 58 +++--------
 .../google/cloud/sensors/test_looker.py | 27 ++----
 .../google/cloud/sensors/test_pubsub.py | 18 ++--
 .../google/cloud/sensors/test_workflows.py | 10 +--
 .../sensors/test_display_video.py | 12 ++-
 .../providers/jenkins/sensors/test_jenkins.py | 12 +--
 .../azure/sensors/test_data_factory.py | 20 ++---
 .../microsoft/azure/sensors/test_wasb.py | 16 +---
 tests/providers/sftp/sensors/test_sftp.py | 14 ++-
 .../providers/tableau/sensors/test_tableau.py | 11 +--
 62 files changed, 225 insertions(+), 795 deletions(-)

diff --git a/airflow/providers/airbyte/sensors/airbyte.py b/airflow/providers/airbyte/sensors/airbyte.py
index 35e552b2d673..36d772d53130
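Every provider hunk in this patch has the same shape: the duplicated `if self.soft_fail: raise AirflowSkipException(...)` branch in front of the hard failure is deleted, together with its `# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1` marker, so each sensor simply raises its normal failure exception. A minimal before/after sketch of that pattern follows, assuming Airflow >= 2.7.1 where, as the removed TODO comments imply, skipping on soft_fail is handled by the sensor machinery itself rather than re-implemented in every provider; the sensor class and its job-status plumbing are illustrative, not taken from any one provider in this patch.

    from airflow.exceptions import AirflowException, AirflowSkipException
    from airflow.sensors.base import BaseSensorOperator


    class ExampleJobSensorBefore(BaseSensorOperator):
        """Pre-patch pattern: every provider re-implemented the soft_fail skip."""

        def __init__(self, *, job_status: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self.job_status = job_status  # stands in for a hook/API call

        def poke(self, context) -> bool:
            if self.job_status == "failed":
                message = "Job failed"
                if self.soft_fail:                       # branch removed by this patch
                    raise AirflowSkipException(message)  # branch removed by this patch
                raise AirflowException(message)
            return self.job_status == "succeeded"


    class ExampleJobSensorAfter(BaseSensorOperator):
        """Post-patch pattern: only the hard failure path remains in poke()."""

        def __init__(self, *, job_status: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self.job_status = job_status  # stands in for a hook/API call

        def poke(self, context) -> bool:
            if self.job_status == "failed":
                raise AirflowException("Job failed")
            return self.job_status == "succeeded"
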
100644 --- a/airflow/providers/airbyte/sensors/airbyte.py +++ b/airflow/providers/airbyte/sensors/airbyte.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.airbyte.hooks.airbyte import AirbyteHook from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger from airflow.sensors.base import BaseSensorOperator @@ -93,16 +93,10 @@ def poke(self, context: Context) -> bool: status = job.json()["status"] if status == hook.FAILED: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Job failed: \n{job}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif status == hook.CANCELLED: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Job was cancelled: \n{job}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif status == hook.SUCCEEDED: self.log.info("Job %s completed successfully.", self.airbyte_job_id) diff --git a/airflow/providers/alibaba/cloud/sensors/oss_key.py b/airflow/providers/alibaba/cloud/sensors/oss_key.py index 6e2b0dec4c58..c9336581ec77 100644 --- a/airflow/providers/alibaba/cloud/sensors/oss_key.py +++ b/airflow/providers/alibaba/cloud/sensors/oss_key.py @@ -23,7 +23,7 @@ from deprecated.classic import deprecated -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.alibaba.cloud.hooks.oss import OSSHook from airflow.sensors.base import BaseSensorOperator @@ -73,23 +73,17 @@ def poke(self, context: Context): parsed_url = urlsplit(self.bucket_key) if self.bucket_name is None: if parsed_url.netloc == "": - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = "If key is a relative path from root, please provide a bucket_name" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.bucket_name = parsed_url.netloc self.bucket_key = parsed_url.path.lstrip("/") else: if parsed_url.scheme != "" or parsed_url.netloc != "": - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = ( "If bucket_name is provided, bucket_key" " should be relative path from root" " level, rather than a full oss:// url" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.log.info("Poking for key : oss://%s/%s", self.bucket_name, self.bucket_key) diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py index 9c1a29f8098f..6ba1da17eb35 100644 --- a/airflow/providers/amazon/aws/sensors/batch.py +++ b/airflow/providers/amazon/aws/sensors/batch.py @@ -23,7 +23,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger from airflow.sensors.base 
import BaseSensorOperator @@ -265,6 +265,4 @@ def poke(self, context: Context) -> bool: return False message = f"AWS Batch job queue failed. AWS Batch job queue status: {status}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py index aba3e5592268..049ed8227528 100644 --- a/airflow/providers/amazon/aws/sensors/ecs.py +++ b/airflow/providers/amazon/aws/sensors/ecs.py @@ -35,7 +35,7 @@ from airflow.utils.context import Context -def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None: +def _check_failed(current_state, target_state, failure_states) -> None: if (current_state != target_state) and (current_state in failure_states): raise AirflowException( f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}" @@ -86,7 +86,7 @@ def poke(self, context: Context): cluster_state = EcsClusterStates(self.hook.get_cluster_state(cluster_name=self.cluster_name)) self.log.info("Cluster state: %s, waiting for: %s", cluster_state, self.target_state) - _check_failed(cluster_state, self.target_state, self.failure_states, self.soft_fail) + _check_failed(cluster_state, self.target_state, self.failure_states) return cluster_state == self.target_state @@ -132,7 +132,7 @@ def poke(self, context: Context): ) self.log.info("Task Definition state: %s, waiting for: %s", task_definition_state, self.target_state) - _check_failed(task_definition_state, self.target_state, [self.failure_states], self.soft_fail) + _check_failed(task_definition_state, self.target_state, [self.failure_states]) return task_definition_state == self.target_state @@ -172,5 +172,5 @@ def poke(self, context: Context): task_state = EcsTaskStates(self.hook.get_task_state(cluster=self.cluster, task=self.task)) self.log.info("Task state: %s, waiting for: %s", task_state, self.target_state) - _check_failed(task_state, self.target_state, self.failure_states, self.soft_fail) + _check_failed(task_state, self.target_state, self.failure_states) return task_state == self.target_state diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 062e4ab3efd0..1209a782afcd 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.glue import ( @@ -177,8 +177,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if event["status"] != "success": message = f"Error: AWS Glue data quality ruleset evaluation run: {event}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.hook.validate_evaluation_run_results( @@ -300,8 +298,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if event["status"] != "success": message = f"Error: AWS Glue data quality recommendation run: {event}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if self.show_results: diff --git 
a/airflow/providers/apache/flink/sensors/flink_kubernetes.py b/airflow/providers/apache/flink/sensors/flink_kubernetes.py index fcca14a8fb5c..39cc7ea8c59b 100644 --- a/airflow/providers/apache/flink/sensors/flink_kubernetes.py +++ b/airflow/providers/apache/flink/sensors/flink_kubernetes.py @@ -21,7 +21,7 @@ from kubernetes import client -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.sensors.base import BaseSensorOperator @@ -125,10 +125,7 @@ def poke(self, context: Context) -> bool: if self.attach_log and application_state in self.FAILURE_STATES + self.SUCCESS_STATES: self._log_driver(application_state, response) if application_state in self.FAILURE_STATES: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Flink application failed with state: {application_state}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif application_state in self.SUCCESS_STATES: self.log.info("Flink application ended successfully") diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py index 9800ccdb5b09..a985c3006ab0 100644 --- a/airflow/providers/celery/sensors/celery_queue.py +++ b/airflow/providers/celery/sensors/celery_queue.py @@ -21,7 +21,6 @@ from celery.app import control -from airflow.exceptions import AirflowSkipException from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -73,13 +72,5 @@ def poke(self, context: Context) -> bool: return reserved == 0 and scheduled == 0 and active == 0 except KeyError: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Could not locate Celery queue {self.celery_queue}" - if self.soft_fail: - raise AirflowSkipException(message) raise KeyError(message) - except Exception as err: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException from err - raise diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 86a4761b5110..3ed142e07acc 100644 --- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -22,7 +22,7 @@ from kubernetes import client -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.sensors.base import BaseSensorOperator @@ -125,10 +125,7 @@ def poke(self, context: Context) -> bool: self._log_driver(application_state, response) if application_state in self.FAILURE_STATES: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Spark application failed with state: {application_state}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif application_state in self.SUCCESS_STATES: self.log.info("Spark application ended successfully") diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py index 01173b2baf49..ece2ea241c94 100644 --- a/airflow/providers/common/sql/sensors/sql.py +++ b/airflow/providers/common/sql/sensors/sql.py @@ -18,7 +18,7 @@ from typing import 
TYPE_CHECKING, Any, Callable, Mapping, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.sensors.base import BaseSensorOperator @@ -97,10 +97,7 @@ def poke(self, context: Context) -> bool: records = hook.get_records(self.sql, self.parameters) if not records: if self.fail_on_empty: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = "No rows returned, raising as per fail_on_empty flag" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) else: return False @@ -109,25 +106,16 @@ def poke(self, context: Context) -> bool: if self.failure is not None: if callable(self.failure): if self.failure(first_cell): - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Failure criteria met. self.failure({first_cell}) returned True" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) else: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"self.failure is present, but not callable -> {self.failure}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if self.success is not None: if callable(self.success): return self.success(first_cell) else: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"self.success is present, but not callable -> {self.success}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) return bool(first_cell) diff --git a/airflow/providers/common/sql/sensors/sql.pyi b/airflow/providers/common/sql/sensors/sql.pyi index 12084e4533f4..db92f6d6e02a 100644 --- a/airflow/providers/common/sql/sensors/sql.pyi +++ b/airflow/providers/common/sql/sensors/sql.pyi @@ -32,10 +32,7 @@ Definition of the public interface for airflow.providers.common.sql.sensors.sql isort:skip_file """ from _typeshed import Incomplete -from airflow.exceptions import ( - AirflowException as AirflowException, - AirflowSkipException as AirflowSkipException, -) +from airflow.exceptions import AirflowException as AirflowException from airflow.hooks.base import BaseHook as BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook from airflow.sensors.base import BaseSensorOperator as BaseSensorOperator diff --git a/airflow/providers/databricks/sensors/databricks_partition.py b/airflow/providers/databricks/sensors/databricks_partition.py index d056cea73ae9..8ae98c99b1f0 100644 --- a/airflow/providers/databricks/sensors/databricks_partition.py +++ b/airflow/providers/databricks/sensors/databricks_partition.py @@ -26,7 +26,7 @@ from databricks.sql.utils import ParamEscaper -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook from airflow.sensors.base import BaseSensorOperator @@ -182,10 +182,7 @@ def _generate_partition_query( partition_columns = self._sql_sensor(f"DESCRIBE DETAIL {table_name}")[0][7] self.log.debug("Partition columns: %s", partition_columns) if len(partition_columns) < 1: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = 
f"Table {table_name} does not have partitions" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) formatted_opts = "" @@ -207,17 +204,11 @@ def _generate_partition_query( f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}""" ) else: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Column {partition_col} not part of table partitions: {partition_columns}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) else: # Raises exception if the table does not have any partitions. - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = "No partitions specified to check with the sensor." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) formatted_opts = f"{prefix} {joiner_val.join(output_list)} {suffix}" self.log.debug("Formatted options: %s", formatted_opts) @@ -231,8 +222,5 @@ def poke(self, context: Context) -> bool: if partition_result: return True else: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Specified partition(s): {self.partitions} were not found." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/databricks/sensors/databricks_sql.py b/airflow/providers/databricks/sensors/databricks_sql.py index 5f215eed3fbb..6b614f66144c 100644 --- a/airflow/providers/databricks/sensors/databricks_sql.py +++ b/airflow/providers/databricks/sensors/databricks_sql.py @@ -23,7 +23,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook from airflow.sensors.base import BaseSensorOperator @@ -117,13 +117,10 @@ def hook(self) -> DatabricksSqlHook: def _get_results(self) -> bool: """Use the Databricks SQL hook and run the specified SQL query.""" if not (self._http_path or self._sql_warehouse_name): - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = ( "Databricks SQL warehouse/cluster configuration missing. Please specify either" " http_path or sql_warehouse_name." 
) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) hook = self.hook sql_result = hook.run( diff --git a/airflow/providers/datadog/sensors/datadog.py b/airflow/providers/datadog/sensors/datadog.py index 08db9acdd82b..0eb4d4fb567d 100644 --- a/airflow/providers/datadog/sensors/datadog.py +++ b/airflow/providers/datadog/sensors/datadog.py @@ -21,7 +21,7 @@ from datadog import api -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.datadog.hooks.datadog import DatadogHook from airflow.sensors.base import BaseSensorOperator @@ -89,10 +89,7 @@ def poke(self, context: Context) -> bool: if isinstance(response, dict) and response.get("status", "ok") != "ok": self.log.error("Unexpected Datadog result: %s", response) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "Datadog returned unexpected result" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if self.response_check: diff --git a/airflow/providers/dbt/cloud/sensors/dbt.py b/airflow/providers/dbt/cloud/sensors/dbt.py index 145d3386bd3b..df733feaf3e8 100644 --- a/airflow/providers/dbt/cloud/sensors/dbt.py +++ b/airflow/providers/dbt/cloud/sensors/dbt.py @@ -24,7 +24,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger from airflow.providers.dbt.cloud.utils.openlineage import generate_openlineage_events_from_dbt_cloud_run @@ -93,17 +93,11 @@ def poke(self, context: Context) -> bool: job_run_status = self.hook.get_job_run_status(run_id=self.run_id, account_id=self.account_id) if job_run_status == DbtCloudJobRunStatus.ERROR.value: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Job run {self.run_id} has failed." - if self.soft_fail: - raise AirflowSkipException(message) raise DbtCloudJobRunException(message) if job_run_status == DbtCloudJobRunStatus.CANCELLED.value: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f"Job run {self.run_id} has been cancelled." - if self.soft_fail: - raise AirflowSkipException(message) raise DbtCloudJobRunException(message) return job_run_status == DbtCloudJobRunStatus.SUCCESS.value @@ -141,9 +135,6 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int: execution was successful. 
""" if event["status"] in ["error", "cancelled"]: - message = f"Error in dbt: {event['message']}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException() self.log.info(event["message"]) return int(event["run_id"]) diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index 72e03d7aba54..203ef9361a03 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -26,7 +26,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.google.cloud.triggers.bigquery import ( BigQueryTableExistenceTrigger, @@ -144,15 +144,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event["status"] == "success": return event["message"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "No event received in trigger callback" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) @@ -260,15 +254,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event["status"] == "success": return event["message"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "No event received in trigger callback" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py index 99d7acb1b7f3..c7e158db0702 100644 --- a/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -24,7 +24,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.bigquery_datatransfer_v1 import TransferState -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.sensors.base import BaseSensorOperator @@ -142,9 +142,6 @@ def poke(self, context: Context) -> bool: self.log.info("Status of %s run: %s", self.run_id, run.state) if run.state in (TransferState.FAILED, TransferState.CANCELLED): - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Transfer {self.run_id} did not succeed" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) return run.state in self.expected_statuses diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py b/airflow/providers/google/cloud/sensors/cloud_composer.py index 0301466eac0a..aa6cb0a35d72 100644 --- 
a/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -28,7 +28,7 @@ from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook from airflow.providers.google.cloud.triggers.cloud_composer import ( CloudComposerDAGRunTrigger, @@ -118,15 +118,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event.get("operation_done"): return event["operation_done"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "No event received in trigger callback" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/google/cloud/sensors/dataflow.py b/airflow/providers/google/cloud/sensors/dataflow.py index c5d9efc74758..eebd03d8d3a9 100644 --- a/airflow/providers/google/cloud/sensors/dataflow.py +++ b/airflow/providers/google/cloud/sensors/dataflow.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataflow import ( DEFAULT_DATAFLOW_LOCATION, DataflowHook, @@ -117,10 +117,7 @@ def poke(self, context: Context) -> bool: if job_status in self.expected_statuses: return True elif job_status in DataflowJobStatus.TERMINAL_STATES: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) return False @@ -154,9 +151,6 @@ def execute_complete(self, context: Context, event: dict[str, str | list]) -> bo if event["status"] == "success": self.log.info(event["message"]) return True - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") raise AirflowException(f"Sensor failed with the following message: {event['message']}") @cached_property @@ -235,10 +229,7 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) result = self.hook.fetch_job_metrics_by_id( @@ -279,9 +270,6 @@ def execute_complete(self, context: Context, event: dict[str, str | list]) -> An if event["status"] == "success": self.log.info(event["message"]) return event["result"] if self.callback is None else self.callback(event["result"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if 
self.soft_fail: - raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") raise AirflowException(f"Sensor failed with the following message: {event['message']}") @cached_property @@ -362,10 +350,7 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) result = self.hook.fetch_job_messages_by_id( @@ -407,9 +392,6 @@ def execute_complete(self, context: Context, event: dict[str, str | list]) -> An if event["status"] == "success": self.log.info(event["message"]) return event["result"] if self.callback is None else self.callback(event["result"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") raise AirflowException(f"Sensor failed with the following message: {event['message']}") @cached_property @@ -490,10 +472,7 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) result = self.hook.fetch_job_autoscaling_events_by_id( @@ -534,9 +513,6 @@ def execute_complete(self, context: Context, event: dict[str, str | list]) -> An if event["status"] == "success": self.log.info(event["message"]) return event["result"] if self.callback is None else self.callback(event["result"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.") raise AirflowException(f"Sensor failed with the following message: {event['message']}") @cached_property diff --git a/airflow/providers/google/cloud/sensors/dataform.py b/airflow/providers/google/cloud/sensors/dataform.py index 6ad72161c7eb..0e4676749eb4 100644 --- a/airflow/providers/google/cloud/sensors/dataform.py +++ b/airflow/providers/google/cloud/sensors/dataform.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataform import DataformHook from airflow.sensors.base import BaseSensorOperator @@ -96,13 +96,10 @@ def poke(self, context: Context) -> bool: workflow_status = workflow_invocation.state if workflow_status is not None: if self.failure_statuses and workflow_status in self.failure_statuses: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Workflow Invocation with id '{self.workflow_invocation_id}' " f"state is: {workflow_status}. Terminating sensor..." 
) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) return workflow_status in self.expected_statuses diff --git a/airflow/providers/google/cloud/sensors/datafusion.py b/airflow/providers/google/cloud/sensors/datafusion.py index 358906859c53..a3bea7635113 100644 --- a/airflow/providers/google/cloud/sensors/datafusion.py +++ b/airflow/providers/google/cloud/sensors/datafusion.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence -from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.sensors.base import BaseSensorOperator @@ -111,22 +111,16 @@ def poke(self, context: Context) -> bool: ) pipeline_status = pipeline_workflow["status"] except AirflowNotFoundException: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "Specified Pipeline ID was not found." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) except AirflowException: pass # Because the pipeline may not be visible in system yet if pipeline_status is not None: if self.failure_statuses and pipeline_status in self.failure_statuses: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Pipeline with id '{self.pipeline_id}' state is: {pipeline_status}. " f"Terminating sensor..." ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.log.debug( diff --git a/airflow/providers/google/cloud/sensors/dataplex.py b/airflow/providers/google/cloud/sensors/dataplex.py index 97af1acf05dd..ea6d2003d828 100644 --- a/airflow/providers/google/cloud/sensors/dataplex.py +++ b/airflow/providers/google/cloud/sensors/dataplex.py @@ -30,7 +30,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.dataplex_v1.types import DataScanJob -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataplex import ( AirflowDataQualityScanException, AirflowDataQualityScanResultTimeoutException, @@ -118,10 +118,7 @@ def poke(self, context: Context) -> bool: task_status = task.state if task_status == TaskState.DELETING: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Task is going to be deleted {self.dataplex_task_id}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.log.info("Current status of the Dataplex task %s => %s", self.dataplex_task_id, task_status) @@ -202,12 +199,9 @@ def poke(self, context: Context) -> bool: if self.result_timeout: duration = self._duration() if duration > self.result_timeout: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Timeout: Data Quality scan {self.job_id} is not ready after {self.result_timeout}s" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowDataQualityScanResultTimeoutException(message) hook = DataplexHook( @@ -227,10 +221,7 @@ def poke(self, context: Context) -> bool: metadata=self.metadata, ) except GoogleAPICallError as e: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = 
f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}" - if self.soft_fail: - raise AirflowSkipException(message, e) raise AirflowException(message, e) job_status = job.state @@ -238,26 +229,17 @@ def poke(self, context: Context) -> bool: "Current status of the Dataplex Data Quality scan job %s => %s", self.job_id, job_status ) if job_status == DataScanJob.State.FAILED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Data Quality scan job failed: {self.job_id}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if job_status == DataScanJob.State.CANCELLED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Data Quality scan job cancelled: {self.job_id}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if self.fail_on_dq_failure: if job_status == DataScanJob.State.SUCCEEDED and not job.data_quality_result.passed: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Data Quality job {self.job_id} execution failed due to failure of its scanning " f"rules: {self.data_scan_id}" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowDataQualityScanException(message) return job_status == DataScanJob.State.SUCCEEDED @@ -330,12 +312,9 @@ def poke(self, context: Context) -> bool: if self.result_timeout: duration = self._duration() if duration > self.result_timeout: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Timeout: Data Profile scan {self.job_id} is not ready after {self.result_timeout}s" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowDataQualityScanResultTimeoutException(message) hook = DataplexHook( @@ -355,10 +334,7 @@ def poke(self, context: Context) -> bool: metadata=self.metadata, ) except GoogleAPICallError as e: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Error occurred when trying to retrieve Data Profile scan job: {self.data_scan_id}" - if self.soft_fail: - raise AirflowSkipException(message, e) raise AirflowException(message, e) job_status = job.state @@ -366,15 +342,9 @@ def poke(self, context: Context) -> bool: "Current status of the Dataplex Data Profile scan job %s => %s", self.job_id, job_status ) if job_status == DataScanJob.State.FAILED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Data Profile scan job failed: {self.job_id}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if job_status == DataScanJob.State.CANCELLED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Data Profile scan job cancelled: {self.job_id}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) return job_status == DataScanJob.State.SUCCEEDED diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py index fad70b1c28e3..f533b56ec5d0 100644 --- a/airflow/providers/google/cloud/sensors/dataproc.py +++ b/airflow/providers/google/cloud/sensors/dataproc.py @@ -25,7 +25,7 @@ from google.api_core.exceptions import ServerError from google.cloud.dataproc_v1.types import Batch, JobStatus -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import 
AirflowException from airflow.providers.google.cloud.hooks.dataproc import DataprocHook from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.sensors.base import BaseSensorOperator @@ -85,13 +85,10 @@ def poke(self, context: Context) -> bool: duration = self._duration() self.log.info("DURATION RUN: %f", duration) if duration > self.wait_timeout: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Timeout: dataproc job {self.dataproc_job_id} " f"is not ready after {self.wait_timeout}s" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err) return False @@ -100,20 +97,14 @@ def poke(self, context: Context) -> bool: state = job.status.state if state == JobStatus.State.ERROR: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job failed:\n{job}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif state in { JobStatus.State.CANCELLED, JobStatus.State.CANCEL_PENDING, JobStatus.State.CANCEL_STARTED, }: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Job was cancelled:\n{job}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif JobStatus.State.DONE == state: self.log.debug("Job %s completed successfully.", self.dataproc_job_id) @@ -185,19 +176,13 @@ def poke(self, context: Context) -> bool: state = batch.state if state == Batch.State.FAILED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "Batch failed" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif state in { Batch.State.CANCELLED, Batch.State.CANCELLING, }: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "Batch was cancelled." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif state == Batch.State.SUCCEEDED: self.log.debug("Batch %s completed successfully.", self.batch_id) diff --git a/airflow/providers/google/cloud/sensors/dataproc_metastore.py b/airflow/providers/google/cloud/sensors/dataproc_metastore.py index 3ebf5c0f3c1d..9413f4329818 100644 --- a/airflow/providers/google/cloud/sensors/dataproc_metastore.py +++ b/airflow/providers/google/cloud/sensors/dataproc_metastore.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook from airflow.providers.google.cloud.hooks.gcs import parse_json_from_gcs from airflow.sensors.base import BaseSensorOperator @@ -99,20 +99,14 @@ def poke(self, context: Context) -> bool: impersonation_chain=self.impersonation_chain, ) if not (manifest and isinstance(manifest, dict)): - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Failed to extract result manifest. 
" f"Expected not empty dict, but this was received: {manifest}" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if manifest.get("status", {}).get("code") != 0: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Request failed: {manifest.get('message')}" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) # Extract actual query results diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index c140f3b279d2..b953ef66eb10 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -28,7 +28,7 @@ from google.cloud.storage.retry import DEFAULT_RETRY from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.triggers.gcs import ( GCSBlobTrigger, @@ -137,9 +137,6 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> bool: Relies on trigger to throw an exception, otherwise it assumes execution was successful. """ if event["status"] == "error": - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info("File %s was found in bucket %s.", self.object, self.bucket) return True @@ -284,15 +281,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None "Checking last updated time for object %s in bucket : %s", self.object, self.bucket ) return event["message"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "No event received in trigger callback" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) @@ -382,9 +373,6 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[ self.log.info("Resuming from trigger and checking status") if event["status"] == "success": return event["matches"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) @@ -514,13 +502,10 @@ def is_bucket_updated(self, current_objects: set[str]) -> bool: ) return False - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( "Illegal behavior: objects were deleted in " f"{os.path.join(self.bucket, self.prefix)} between pokes." 
) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if self.last_activity_time: @@ -592,13 +577,7 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event["status"] == "success": return event["message"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "No event received in trigger callback" - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/google/cloud/sensors/looker.py b/airflow/providers/google/cloud/sensors/looker.py index 8c1f618b3367..ef51abcb2e01 100644 --- a/airflow/providers/google/cloud/sensors/looker.py +++ b/airflow/providers/google/cloud/sensors/looker.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.looker import JobStatus, LookerHook from airflow.sensors.base import BaseSensorOperator @@ -54,10 +54,7 @@ def poke(self, context: Context) -> bool: self.hook = LookerHook(looker_conn_id=self.looker_conn_id) if not self.materialization_id: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = "Invalid `materialization_id`." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) # materialization_id is templated var pulling output from start task @@ -66,22 +63,13 @@ def poke(self, context: Context) -> bool: if status == JobStatus.ERROR.value: msg = status_dict["message"] - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f'PDT materialization job failed. Job id: {self.materialization_id}. Message:\n"{msg}"' - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif status == JobStatus.CANCELLED.value: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"PDT materialization job was cancelled. Job id: {self.materialization_id}." - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif status == JobStatus.UNKNOWN.value: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"PDT materialization job has unknown status. Job id: {self.materialization_id}." 
- if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) elif status == JobStatus.DONE.value: self.log.debug( diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index eae3aef05a75..55acee3d7034 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -25,7 +25,7 @@ from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.pubsub import PubSubHook from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger from airflow.sensors.base import BaseSensorOperator @@ -175,9 +175,6 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[ self.log.info("Sensor pulls messages: %s", event["message"]) return event["message"] self.log.info("Sensor failed: %s", event["message"]) - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) def _default_message_callback( diff --git a/airflow/providers/google/cloud/sensors/workflows.py b/airflow/providers/google/cloud/sensors/workflows.py index d1aa1b7696c6..aeacef4636e2 100644 --- a/airflow/providers/google/cloud/sensors/workflows.py +++ b/airflow/providers/google/cloud/sensors/workflows.py @@ -21,7 +21,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.workflows.executions_v1beta import Execution -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.sensors.base import BaseSensorOperator @@ -101,13 +101,10 @@ def poke(self, context: Context): state = execution.state if state in self.failure_states: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Execution {self.execution_id} for workflow {self.execution_id} " f"failed and is in `{state}` state" ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) if state in self.success_states: diff --git a/airflow/providers/google/marketing_platform/sensors/display_video.py b/airflow/providers/google/marketing_platform/sensors/display_video.py index e6f1298ab8f5..869317afa46e 100644 --- a/airflow/providers/google/marketing_platform/sensors/display_video.py +++ b/airflow/providers/google/marketing_platform/sensors/display_video.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook from airflow.sensors.base import BaseSensorOperator @@ -88,10 +88,7 @@ def poke(self, context: Context) -> bool: ) operation = hook.get_sdf_download_operation(operation_name=self.operation_name) if "error" in operation: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = f'The operation finished in error with {operation["error"]}' - if self.soft_fail: - raise AirflowSkipException(message) raise 
AirflowException(message) if operation and operation.get("done"): return True diff --git a/airflow/providers/jenkins/sensors/jenkins.py b/airflow/providers/jenkins/sensors/jenkins.py index 264c7a124594..d0018b24eeda 100644 --- a/airflow/providers/jenkins/sensors/jenkins.py +++ b/airflow/providers/jenkins/sensors/jenkins.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.jenkins.hooks.jenkins import JenkinsHook from airflow.sensors.base import BaseSensorOperator @@ -68,11 +68,8 @@ def poke(self, context: Context) -> bool: if build_result in self.target_states: return True else: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = ( f"Build {build_number} finished with a result {build_result}, " f"which does not meet the target state {self.target_states}." ) - if self.soft_fail: - raise AirflowSkipException(message) raise AirflowException(message) diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index c13737c83145..a8abba255ef0 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -85,17 +85,11 @@ def poke(self, context: Context) -> bool: ) if pipeline_run_status == AzureDataFactoryPipelineRunStatus.FAILED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Pipeline run {self.run_id} has failed." - if self.soft_fail: - raise AirflowSkipException(message) raise AzureDataFactoryPipelineRunException(message) if pipeline_run_status == AzureDataFactoryPipelineRunStatus.CANCELLED: - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 message = f"Pipeline run {self.run_id} has been cancelled." 
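The deferrable sensors touched by this patch (Azure Data Factory and WASB below, plus the BigQuery, Cloud Composer, GCS, Dataflow and dbt Cloud sensors earlier) get the matching change in their trigger callbacks: `execute_complete` raises `AirflowException` on an error event instead of first consulting `self.soft_fail`. The sketch below shows the resulting callback shape as a standalone function for illustration only; it is a composite of the hunks in this patch, and the real methods are instance methods that also log the success message and differ in what they return.

    from __future__ import annotations

    from airflow.exceptions import AirflowException


    def handle_trigger_event(event: dict[str, str] | None) -> str:
        # Post-patch shape of execute_complete: no soft_fail branch; an error event
        # (or a missing event) becomes a plain AirflowException.
        if event:
            if event["status"] == "error":
                raise AirflowException(event["message"])
            return event["message"]
        raise AirflowException("No event received in trigger callback")
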
- if self.soft_fail: - raise AirflowSkipException(message) raise AzureDataFactoryPipelineRunException(message) return pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED @@ -131,9 +125,6 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) return None diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index bcec7dc8b4be..d5f5cd424bff 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -23,7 +23,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.providers.microsoft.azure.triggers.wasb import WasbBlobSensorTrigger, WasbPrefixSensorTrigger from airflow.sensors.base import BaseSensorOperator @@ -104,9 +104,6 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) else: @@ -216,9 +213,6 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": - # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) else: diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py index f56ad9341001..f6b076331fe5 100644 --- a/airflow/providers/sftp/sensors/sftp.py +++ b/airflow/providers/sftp/sensors/sftp.py @@ -26,7 +26,7 @@ from paramiko.sftp import SFTP_NO_SUCH_FILE from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.triggers.sftp import SFTPTrigger from airflow.sensors.base import BaseSensorOperator, PokeReturnValue @@ -98,10 +98,7 @@ def poke(self, context: Context) -> PokeReturnValue | bool: self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time) except OSError as e: if e.errno != SFTP_NO_SUCH_FILE: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException from e - raise e + raise AirflowException from e continue if self.newer_than: diff --git a/airflow/providers/tableau/sensors/tableau.py b/airflow/providers/tableau/sensors/tableau.py index bdc72df003ca..80eb68474625 100644 --- a/airflow/providers/tableau/sensors/tableau.py +++ b/airflow/providers/tableau/sensors/tableau.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowSkipException from airflow.providers.tableau.hooks.tableau import ( 
TableauHook, TableauJobFailedException, @@ -69,10 +68,7 @@ def poke(self, context: Context) -> bool: self.log.info("Current finishCode is %s (%s)", finish_code.name, finish_code.value) if finish_code in (TableauJobFinishCode.ERROR, TableauJobFinishCode.CANCELED): - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 message = "The Tableau Refresh Workbook Job failed!" - if self.soft_fail: - raise AirflowSkipException(message) raise TableauJobFailedException(message) return finish_code == TableauJobFinishCode.SUCCESS diff --git a/tests/providers/airbyte/sensors/test_airbyte.py b/tests/providers/airbyte/sensors/test_airbyte.py index 0609028c3ce0..80a7151a6f09 100644 --- a/tests/providers/airbyte/sensors/test_airbyte.py +++ b/tests/providers/airbyte/sensors/test_airbyte.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor @@ -48,20 +48,16 @@ def test_done(self, mock_get_job): mock_get_job.assert_called_once_with(job_id=self.job_id) assert ret - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job") - def test_failed(self, mock_get_job, soft_fail: bool, expected_exception: type[AirflowException]): + def test_failed(self, mock_get_job): mock_get_job.return_value = self.get_job("failed") sensor = AirbyteJobSensor( task_id=self.task_id, airbyte_job_id=self.job_id, airbyte_conn_id=self.airbyte_conn_id, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Job failed"): + with pytest.raises(AirflowException, match="Job failed"): sensor.poke(context={}) mock_get_job.assert_called_once_with(job_id=self.job_id) @@ -81,20 +77,16 @@ def test_running(self, mock_get_job): assert not ret - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job") - def test_cancelled(self, mock_get_job, soft_fail: bool, expected_exception: type[AirflowException]): + def test_cancelled(self, mock_get_job): mock_get_job.return_value = self.get_job("cancelled") sensor = AirbyteJobSensor( task_id=self.task_id, airbyte_job_id=self.job_id, airbyte_conn_id=self.airbyte_conn_id, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Job was cancelled"): + with pytest.raises(AirflowException, match="Job was cancelled"): sensor.poke(context={}) mock_get_job.assert_called_once_with(job_id=self.job_id) diff --git a/tests/providers/alibaba/cloud/sensors/test_oss_key.py b/tests/providers/alibaba/cloud/sensors/test_oss_key.py index ba68cdc926ee..2ef3c5b38d27 100644 --- a/tests/providers/alibaba/cloud/sensors/test_oss_key.py +++ b/tests/providers/alibaba/cloud/sensors/test_oss_key.py @@ -22,7 +22,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.alibaba.cloud.sensors.oss_key import OSSKeySensor MODULE_NAME = "airflow.providers.alibaba.cloud.sensors.oss_key" @@ -79,20 +79,18 @@ def test_poke_non_exsiting_key(self, mock_service, oss_key_sensor): assert res is False mock_service.return_value.object_exists.assert_called_once_with(key=MOCK_KEY, bucket_name=MOCK_BUCKET) - @pytest.mark.parametrize( - "soft_fail, 
expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(f"{MODULE_NAME}.OSSKeySensor.hook", new_callable=PropertyMock) def test_poke_without_bucket_name( - self, mock_service, oss_key_sensor, soft_fail: bool, expected_exception: AirflowException + self, + mock_service, + oss_key_sensor, ): # Given - oss_key_sensor.soft_fail = soft_fail oss_key_sensor.bucket_name = None mock_service.return_value.object_exists.return_value = False # When, Then with pytest.raises( - expected_exception, match="If key is a relative path from root, please provide a bucket_name" + AirflowException, match="If key is a relative path from root, please provide a bucket_name" ): oss_key_sensor.poke(None) diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py index a8ec1b926bb5..216987127256 100644 --- a/tests/providers/amazon/aws/sensors/test_batch.py +++ b/tests/providers/amazon/aws/sensors/test_batch.py @@ -106,14 +106,10 @@ def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: Batch ), ) @mock.patch.object(BatchClientHook, "get_job_description") - def test_fail_poke( - self, - mock_get_job_description, - batch_sensor: BatchSensor, - state, - ): + def test_fail_poke(self, mock_get_job_description, state): mock_get_job_description.return_value = {"status": state} - with pytest.raises(AirflowException, match=f"Batch sensor failed. AWS Batch job status: {state}"): + batch_sensor = BatchSensor(task_id="batch_job_sensor", job_id=JOB_ID) + with pytest.raises(AirflowException): batch_sensor.poke({}) diff --git a/tests/providers/apache/flink/sensors/test_flink_kubernetes.py b/tests/providers/apache/flink/sensors/test_flink_kubernetes.py index 59b794702ccd..9354c0ea0233 100644 --- a/tests/providers/apache/flink/sensors/test_flink_kubernetes.py +++ b/tests/providers/apache/flink/sensors/test_flink_kubernetes.py @@ -27,7 +27,7 @@ from kubernetes.client.rest import ApiException from airflow import DAG -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.apache.flink.sensors.flink_kubernetes import FlinkKubernetesSensor from airflow.utils import db, timezone @@ -903,20 +903,15 @@ def test_cluster_ready_state(self, mock_get_namespaced_crd, mock_kubernetes_hook version="v1beta1", ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch( "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object", return_value=TEST_ERROR_CLUSTER, ) - def test_cluster_error_state( - self, mock_get_namespaced_crd, mock_kubernetes_hook, soft_fail, expected_exception - ): + def test_cluster_error_state(self, mock_get_namespaced_crd, mock_kubernetes_hook): sensor = FlinkKubernetesSensor( - application_name="flink-stream-example", dag=self.dag, task_id="test_task_id", soft_fail=soft_fail + application_name="flink-stream-example", dag=self.dag, task_id="test_task_id" ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): sensor.poke(None) mock_kubernetes_hook.assert_called_once_with() mock_get_namespaced_crd.assert_called_once_with( diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py index e1b3d97c9720..6fcd9900bf6e 100644 --- a/tests/providers/celery/sensors/test_celery_queue.py +++ 
b/tests/providers/celery/sensors/test_celery_queue.py @@ -21,7 +21,6 @@ import pytest -from airflow.exceptions import AirflowSkipException from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor @@ -57,18 +56,15 @@ def test_poke_fail(self, mock_inspect): test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task") assert not test_sensor.poke(None) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, KeyError), (True, AirflowSkipException)) - ) @patch("celery.app.control.Inspect") - def test_poke_fail_with_exception(self, mock_inspect, soft_fail, expected_exception): + def test_poke_fail_with_exception(self, mock_inspect): mock_inspect_result = mock_inspect.return_value mock_inspect_result.reserved.return_value = {} mock_inspect_result.scheduled.return_value = {} mock_inspect_result.active.return_value = {} - test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task", soft_fail=soft_fail) + test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task") - with pytest.raises(expected_exception): + with pytest.raises(KeyError): test_sensor.poke(None) @patch("celery.app.control.Inspect") diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py index ac0f56a309e4..cd7ea51fdc6c 100644 --- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py @@ -24,7 +24,7 @@ from kubernetes.client.rest import ApiException from airflow import DAG -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor from airflow.utils import db, timezone @@ -597,9 +597,6 @@ def test_completed_application(self, mock_get_namespaced_crd, mock_kubernetes_ho version="v1beta2", ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch( "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object", return_value=TEST_FAILED_APPLICATION, @@ -608,13 +605,9 @@ def test_failed_application( self, mock_get_namespaced_crd, mock_kubernetes_hook, - soft_fail: bool, - expected_exception: type[AirflowException], ): - sensor = SparkKubernetesSensor( - application_name="spark_pi", dag=self.dag, task_id="test_task_id", soft_fail=soft_fail - ) - with pytest.raises(expected_exception): + sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id") + with pytest.raises(AirflowException): sensor.poke({}) mock_kubernetes_hook.assert_called_once_with() mock_get_namespaced_crd.assert_called_once_with( @@ -705,9 +698,6 @@ def test_pending_rerun_application(self, mock_get_namespaced_crd, mock_kubernete version="v1beta2", ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch( "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object", return_value=TEST_UNKNOWN_APPLICATION, @@ -716,13 +706,9 @@ def test_unknown_application( self, mock_get_namespaced_crd, mock_kubernetes_hook, - soft_fail: bool, - expected_exception: AirflowException, ): - sensor = SparkKubernetesSensor( - application_name="spark_pi", dag=self.dag, task_id="test_task_id", soft_fail=soft_fail - ) - with 
pytest.raises(expected_exception): + sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id") + with pytest.raises(AirflowException): sensor.poke({}) mock_kubernetes_hook.assert_called_once_with() mock_get_namespaced_crd.assert_called_once_with( @@ -801,9 +787,6 @@ def test_namespace_from_connection(self, mock_get_namespaced_crd, mock_kubernete version="v1beta2", ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch( "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object", return_value=TEST_FAILED_APPLICATION, @@ -819,17 +802,14 @@ def test_driver_logging_failure( error_log_call, mock_get_namespaced_crd, mock_kube_conn, - soft_fail: bool, - expected_exception: AirflowException, ): sensor = SparkKubernetesSensor( application_name="spark_pi", attach_log=True, dag=self.dag, task_id="test_task_id", - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): sensor.poke({}) mock_log_call.assert_called_once_with( "spark-pi-driver", namespace="default", container="spark-kubernetes-driver" diff --git a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py index ee07c1b0ec0b..f4437a265a08 100644 --- a/tests/providers/common/sql/sensors/test_sql.py +++ b/tests/providers/common/sql/sensors/test_sql.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models.dag import DAG from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.sensors.sql import SqlSensor @@ -124,26 +124,20 @@ def test_sql_sensor_postgres_poke(self, mock_hook): mock_get_records.return_value = [["1"]] assert op.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_fail_on_empty( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] - ): + def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", fail_on_empty=True, - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): op.poke({}) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") @@ -164,19 +158,16 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook): mock_get_records.return_value = [["1"]] assert not op.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") def test_sql_sensor_postgres_poke_failure( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] + self, + mock_hook, ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1], - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -186,15 +177,13 @@ def 
test_sql_sensor_postgres_poke_failure( assert not op.poke({}) mock_get_records.return_value = [[1]] - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): op.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") def test_sql_sensor_postgres_poke_failure_success( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] + self, + mock_hook, ): op = SqlSensor( task_id="sql_sensor_check", @@ -202,7 +191,6 @@ def test_sql_sensor_postgres_poke_failure_success( sql="SELECT 1", failure=lambda x: x in [1], success=lambda x: x in [2], - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -212,26 +200,20 @@ def test_sql_sensor_postgres_poke_failure_success( assert not op.poke({}) mock_get_records.return_value = [[1]] - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): op.poke({}) mock_get_records.return_value = [[2]] assert op.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_failure_success_same( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] - ): + def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1], success=lambda x: x in [1], - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -241,52 +223,43 @@ def test_sql_sensor_postgres_poke_failure_success_same( assert not op.poke({}) mock_get_records.return_value = [[1]] - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): op.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_invalid_failure( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] - ): + def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=[1], # type: ignore[arg-type] - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [[1]] - with pytest.raises(expected_exception) as ctx: + with pytest.raises(AirflowException) as ctx: op.poke({}) assert "self.failure is present, but not callable -> [1]" == str(ctx.value) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") def test_sql_sensor_postgres_poke_invalid_success( - self, mock_hook, soft_fail: bool, expected_exception: type[AirflowException] + self, + mock_hook, ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", success=[1], # type: ignore[arg-type] - soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [[1]] - with pytest.raises(expected_exception) as ctx: + with pytest.raises(AirflowException) as ctx: op.poke({}) assert "self.success is present, but not callable -> [1]" == str(ctx.value) diff --git a/tests/providers/databricks/sensors/test_databricks_partition.py b/tests/providers/databricks/sensors/test_databricks_partition.py index c9fed9efd67a..e481e625c61b 100644 --- a/tests/providers/databricks/sensors/test_databricks_partition.py +++ b/tests/providers/databricks/sensors/test_databricks_partition.py @@ -23,7 +23,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import DAG from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor @@ -102,62 +102,45 @@ def test_partition_sensor(self, patched_poke): patched_poke.return_value = True assert self.partition_sensor.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch("airflow.providers.databricks.sensors.databricks_partition.DatabricksPartitionSensor._sql_sensor") - def test_fail__generate_partition_query(self, _sql_sensor, soft_fail, expected_exception): - self.partition_sensor.soft_fail = soft_fail + def test_fail__generate_partition_query(self, _sql_sensor): table_name = "test" _sql_sensor.return_value = [[[], [], [], [], [], [], [], []]] - with pytest.raises(expected_exception, match=f"Table {table_name} does not have partitions"): + with pytest.raises(AirflowException, match=f"Table {table_name} does not have partitions"): self.partition_sensor._generate_partition_query( prefix="", suffix="", joiner_val="", table_name=table_name ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch("airflow.providers.databricks.sensors.databricks_partition.DatabricksPartitionSensor._sql_sensor") - def test_fail__generate_partition_query_with_partition_col_mismatch( - self, _sql_sensor, soft_fail, expected_exception - ): - self.partition_sensor.soft_fail = soft_fail + def test_fail__generate_partition_query_with_partition_col_mismatch(self, _sql_sensor): partition_col = "non_existent_col" partition_columns = ["col1", "col2"] _sql_sensor.return_value = [[[], [], [], [], [], [], [], partition_columns]] - with pytest.raises(expected_exception, match=f"Column {partition_col} not part of table partitions"): + with pytest.raises(AirflowException, match=f"Column {partition_col} not part of table partitions"): self.partition_sensor._generate_partition_query( prefix="", suffix="", joiner_val="", table_name="", opts={partition_col: "1"} ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch("airflow.providers.databricks.sensors.databricks_partition.DatabricksPartitionSensor._sql_sensor") def test_fail__generate_partition_query_with_missing_opts( - self, _sql_sensor, soft_fail, expected_exception + self, + _sql_sensor, ): - self.partition_sensor.soft_fail = soft_fail _sql_sensor.return_value = [[[], [], [], [], [], [], [], ["col1", "col2"]]] - with pytest.raises(expected_exception, match="No partitions specified to check with the sensor."): + with 
pytest.raises(AirflowException, match="No partitions specified to check with the sensor."): self.partition_sensor._generate_partition_query( prefix="", suffix="", joiner_val="", table_name="" ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch( "airflow.providers.databricks.sensors.databricks_partition.DatabricksPartitionSensor" "._check_table_partitions" ) - def test_fail_poke(self, _check_table_partitions, soft_fail, expected_exception): - self.partition_sensor.soft_fail = soft_fail + def test_fail_poke(self, _check_table_partitions): partitions = "test" self.partition_sensor.partitions = partitions _check_table_partitions.return_value = False with pytest.raises( - expected_exception, match=rf"Specified partition\(s\): {partitions} were not found." + AirflowException, match=rf"Specified partition\(s\): {partitions} were not found." ): self.partition_sensor.poke(context={}) diff --git a/tests/providers/databricks/sensors/test_databricks_sql.py b/tests/providers/databricks/sensors/test_databricks_sql.py index 7a3961f79fac..11de84996667 100644 --- a/tests/providers/databricks/sensors/test_databricks_sql.py +++ b/tests/providers/databricks/sensors/test_databricks_sql.py @@ -23,7 +23,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import DAG from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor from airflow.utils import timezone @@ -98,15 +98,11 @@ def test_sql_warehouse_http_path(self): with pytest.raises(AirflowException): _sensor_without_sql_warehouse_http._get_results() - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail__get_results(self, soft_fail, expected_exception): + def test_fail__get_results(self): self.sensor._http_path = None self.sensor._sql_warehouse_name = None - self.sensor.soft_fail = soft_fail with pytest.raises( - expected_exception, + AirflowException, match="Databricks SQL warehouse/cluster configuration missing." 
" Please specify either http_path or sql_warehouse_name.", ): diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py index 86633ecd5cf0..4bc28b39bbc8 100644 --- a/tests/providers/datadog/sensors/test_datadog.py +++ b/tests/providers/datadog/sensors/test_datadog.py @@ -22,7 +22,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.datadog.sensors.datadog import DatadogSensor from airflow.utils import db @@ -117,12 +117,9 @@ def test_sensor_fail(self, api1, api2): assert not sensor.poke({}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @patch("airflow.providers.datadog.hooks.datadog.api.Event.query") @patch("airflow.providers.datadog.sensors.datadog.api.Event.query") - def test_sensor_fail_with_exception(self, api1, api2, soft_fail, expected_exception): + def test_sensor_fail_with_exception(self, api1, api2): api1.return_value = zero_events api2.return_value = {"status": "error"} @@ -135,7 +132,6 @@ def test_sensor_fail_with_exception(self, api1, api2, soft_fail, expected_except sources=None, tags=None, response_check=None, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): sensor.poke({}) diff --git a/tests/providers/dbt/cloud/sensors/test_dbt.py b/tests/providers/dbt/cloud/sensors/test_dbt.py index 6da4bfc24cc7..c7dd78b1c69c 100644 --- a/tests/providers/dbt/cloud/sensors/test_dbt.py +++ b/tests/providers/dbt/cloud/sensors/test_dbt.py @@ -24,7 +24,6 @@ from airflow.exceptions import ( AirflowException, AirflowProviderDeprecationWarning, - AirflowSkipException, TaskDeferred, ) from airflow.models.connection import Connection @@ -83,9 +82,6 @@ def test_poke(self, mock_job_run_status, job_run_status, expected_poke_result): assert self.sensor.poke({}) == expected_poke_result - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, DbtCloudJobRunException), (True, AirflowSkipException)) - ) @pytest.mark.parametrize( argnames=("job_run_status", "expected_poke_result"), argvalues=[ @@ -94,10 +90,7 @@ def test_poke(self, mock_job_run_status, job_run_status, expected_poke_result): ], ) @patch.object(DbtCloudHook, "get_job_run_status") - def test_poke_with_exception( - self, mock_job_run_status, job_run_status, expected_poke_result, soft_fail: bool, expected_exception - ): - self.sensor.soft_fail = soft_fail + def test_poke_with_exception(self, mock_job_run_status, job_run_status, expected_poke_result): mock_job_run_status.return_value = job_run_status # The sensor should fail if the job run status is 20 (aka Error) or 30 (aka Cancelled). @@ -106,7 +99,7 @@ def test_poke_with_exception( else: error_message = f"Job run {RUN_ID} has been cancelled." 
- with pytest.raises(expected_exception, match=error_message): + with pytest.raises(DbtCloudJobRunException, match=error_message): self.sensor.poke({}) @mock.patch("airflow.providers.dbt.cloud.sensors.dbt.DbtCloudHook") diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py index 36a331c78c8e..0fd22b3d7806 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery.py +++ b/tests/providers/google/cloud/sensors/test_bigquery.py @@ -23,7 +23,6 @@ from airflow.exceptions import ( AirflowException, AirflowProviderDeprecationWarning, - AirflowSkipException, TaskDeferred, ) from airflow.providers.google.cloud.sensors.bigquery import ( @@ -105,10 +104,7 @@ def test_execute_deferred(self, mock_hook): exc.value.trigger, BigQueryTableExistenceTrigger ), "Trigger is not a BigQueryTableExistenceTrigger" - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_execute_deferred_failure(self, soft_fail, expected_exception): + def test_execute_deferred_failure(self): """Tests that an expected exception is raised in case of error event""" task = BigQueryTableExistenceSensor( task_id="task-id", @@ -116,9 +112,8 @@ def test_execute_deferred_failure(self, soft_fail, expected_exception): dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) def test_execute_complete(self): @@ -135,19 +130,12 @@ def test_execute_complete(self): task.execute_complete(context={}, event={"status": "success", "message": "Job completed"}) mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_execute_defered_complete_event_none(self, soft_fail, expected_exception): + def test_execute_defered_complete_event_none(self): """Asserts that logging occurs as expected""" task = BigQueryTableExistenceSensor( - task_id="task-id", - project_id=TEST_PROJECT_ID, - dataset_id=TEST_DATASET_ID, - table_id=TEST_TABLE_ID, - soft_fail=soft_fail, + task_id="task-id", project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event=None) @@ -219,10 +207,7 @@ def test_execute_with_deferrable_mode(self, mock_hook): exc.value.trigger, BigQueryTablePartitionExistenceTrigger ), "Trigger is not a BigQueryTablePartitionExistenceTrigger" - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_execute_with_deferrable_mode_execute_failure(self, soft_fail, expected_exception): + def test_execute_with_deferrable_mode_execute_failure(self): """Tests that an AirflowException is raised in case of error event""" task = BigQueryTablePartitionExistenceSensor( task_id="test_task_id", @@ -231,15 +216,11 @@ def test_execute_with_deferrable_mode_execute_failure(self, soft_fail, expected_ table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event={"status": "error", "message": 
"test failure message"}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_execute_complete_event_none(self, soft_fail, expected_exception): + def test_execute_complete_event_none(self): """Asserts that logging occurs as expected""" task = BigQueryTablePartitionExistenceSensor( task_id="task-id", @@ -248,9 +229,8 @@ def test_execute_complete_event_none(self, soft_fail, expected_exception): table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="No event received in trigger callback"): + with pytest.raises(AirflowException, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_execute_complete(self): @@ -308,10 +288,7 @@ def test_big_query_table_existence_sensor_async(self, mock_hook): exc.value.trigger, BigQueryTableExistenceTrigger ), "Trigger is not a BigQueryTableExistenceTrigger" - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_big_query_table_existence_sensor_async_execute_failure(self, soft_fail, expected_exception): + def test_big_query_table_existence_sensor_async_execute_failure(self): """Tests that an expected_exception is raised in case of error event""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistenceAsyncSensor( @@ -319,9 +296,8 @@ def test_big_query_table_existence_sensor_async_execute_failure(self, soft_fail, project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) def test_big_query_table_existence_sensor_async_execute_complete(self): @@ -338,10 +314,9 @@ def test_big_query_table_existence_sensor_async_execute_complete(self): task.execute_complete(context={}, event={"status": "success", "message": "Job completed"}) mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_big_query_sensor_async_execute_complete_event_none(self, soft_fail, expected_exception): + def test_big_query_sensor_async_execute_complete_event_none( + self, + ): """Asserts that logging occurs as expected""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistenceAsyncSensor( @@ -349,9 +324,8 @@ def test_big_query_sensor_async_execute_complete_event_none(self, soft_fail, exp project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event=None) @@ -384,12 +358,7 @@ def test_big_query_table_existence_partition_sensor_async(self, mock_hook): exc.value.trigger, BigQueryTablePartitionExistenceTrigger ), "Trigger is not a BigQueryTablePartitionExistenceTrigger" - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_big_query_table_existence_partition_sensor_async_execute_failure( - self, soft_fail, expected_exception - ): + def 
test_big_query_table_existence_partition_sensor_async_execute_failure(self): """Tests that an expected exception is raised in case of error event""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistencePartitionAsyncSensor( @@ -398,17 +367,11 @@ def test_big_query_table_existence_partition_sensor_async_execute_failure( dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_big_query_table_existence_partition_sensor_async_execute_complete_event_none( - self, soft_fail, expected_exception - ): + def test_big_query_table_existence_partition_sensor_async_execute_complete_event_none(self): """Asserts that logging occurs as expected""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistencePartitionAsyncSensor( @@ -417,9 +380,8 @@ def test_big_query_table_existence_partition_sensor_async_execute_complete_event dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="No event received in trigger callback"): + with pytest.raises(AirflowException, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_big_query_table_existence_partition_sensor_async_execute_complete(self): diff --git a/tests/providers/google/cloud/sensors/test_bigtable.py b/tests/providers/google/cloud/sensors/test_bigtable.py index dea84fb9f951..37bd5eaf8fa0 100644 --- a/tests/providers/google/cloud/sensors/test_bigtable.py +++ b/tests/providers/google/cloud/sensors/test_bigtable.py @@ -24,7 +24,7 @@ from google.cloud.bigtable.instance import Instance from google.cloud.bigtable.table import ClusterState -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor PROJECT_ID = "test_project_id" @@ -35,9 +35,6 @@ class BigtableWaitForTableReplicationTest: - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @pytest.mark.parametrize( "missing_attribute, project_id, instance_id, table_id", [ @@ -46,10 +43,8 @@ class BigtableWaitForTableReplicationTest: ], ) @mock.patch("airflow.providers.google.cloud.sensors.bigtable.BigtableHook") - def test_empty_attribute( - self, missing_attribute, project_id, instance_id, table_id, mock_hook, soft_fail, expected_exception - ): - with pytest.raises(expected_exception) as ctx: + def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook): + with pytest.raises(AirflowException) as ctx: BigtableTableReplicationCompletedSensor( project_id=project_id, instance_id=instance_id, @@ -57,7 +52,6 @@ def test_empty_attribute( task_id="id", gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) err = ctx.value assert str(err) == f"Empty parameter: {missing_attribute}" diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py 
b/tests/providers/google/cloud/sensors/test_cloud_composer.py index d1a8768dafa1..4c1ba79d2544 100644 --- a/tests/providers/google/cloud/sensors/test_cloud_composer.py +++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py @@ -26,7 +26,6 @@ from airflow.exceptions import ( AirflowException, AirflowProviderDeprecationWarning, - AirflowSkipException, TaskDeferred, ) from airflow.providers.google.cloud.sensors.cloud_composer import ( @@ -86,10 +85,9 @@ def test_cloud_composer_existence_sensor_async(self): exc.value.trigger, CloudComposerExecutionTrigger ), "Trigger is not a CloudComposerExecutionTrigger" - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_cloud_composer_existence_sensor_async_execute_failure(self, soft_fail, expected_exception): + def test_cloud_composer_existence_sensor_async_execute_failure( + self, + ): """Tests that an expected exception is raised in case of error event.""" with pytest.warns(AirflowProviderDeprecationWarning, match=DEPRECATION_MESSAGE): task = CloudComposerEnvironmentSensor( @@ -97,9 +95,8 @@ def test_cloud_composer_existence_sensor_async_execute_failure(self, soft_fail, project_id=TEST_PROJECT_ID, region=TEST_REGION, operation_name=TEST_OPERATION_NAME, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="No event received in trigger callback"): + with pytest.raises(AirflowException, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_cloud_composer_existence_sensor_async_execute_complete(self): diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py b/tests/providers/google/cloud/sensors/test_dataflow.py index 519ace52a4a1..50efd44c6990 100644 --- a/tests/providers/google/cloud/sensors/test_dataflow.py +++ b/tests/providers/google/cloud/sensors/test_dataflow.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.sensors.dataflow import ( DataflowJobAutoScalingEventsSensor, @@ -77,11 +77,8 @@ def test_poke(self, mock_hook, expected_status, current_status, sensor_return): job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): + def test_poke_raise_exception(self, mock_hook): mock_get_job = mock_hook.return_value.get_job task = DataflowJobStatusSensor( task_id=TEST_TASK_ID, @@ -91,12 +88,11 @@ def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED} with pytest.raises( - expected_exception, + AirflowException, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_CANCELLED}", ): @@ -157,14 +153,7 @@ def test_execute_complete_success(self): ) assert actual_message == expected_result - @pytest.mark.parametrize( - "expected_exception, soft_fail", - ( - 
(AirflowException, False), - (AirflowSkipException, True), - ), - ) - def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + def test_execute_complete_not_success_status_raises_exception(self): """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" task = DataflowJobStatusSensor( task_id=TEST_TASK_ID, @@ -175,9 +164,8 @@ def test_execute_complete_not_success_status_raises_exception(self, expected_exc gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete(context=None, event={"status": "error", "message": "test error message"}) @@ -221,12 +209,8 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): mock_fetch_job_metrics_by_id.return_value.__getitem__.assert_called_once_with("metrics") callback.assert_called_once_with(mock_fetch_job_metrics_by_id.return_value.__getitem__.return_value) - @pytest.mark.parametrize( - "soft_fail, expected_exception", - ((False, AirflowException), (True, AirflowSkipException)), - ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): + def test_poke_raise_exception(self, mock_hook): mock_get_job = mock_hook.return_value.get_job mock_fetch_job_messages_by_id = mock_hook.return_value.fetch_job_messages_by_id callback = mock.MagicMock() @@ -240,12 +224,11 @@ def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} with pytest.raises( - expected_exception, + AirflowException, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_DONE}", ): @@ -345,14 +328,7 @@ def test_execute_complete_success_with_callback_function(self): ) assert actual_result == expected_result - @pytest.mark.parametrize( - "expected_exception, soft_fail", - ( - (AirflowException, False), - (AirflowSkipException, True), - ), - ) - def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + def test_execute_complete_not_success_status_raises_exception(self): """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" task = DataflowJobMetricsSensor( task_id=TEST_TASK_ID, @@ -364,9 +340,8 @@ def test_execute_complete_not_success_status_raises_exception(self, expected_exc gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete( context=None, event={"status": "error", "message": "test error message", "result": None} ) @@ -412,12 +387,8 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): ) callback.assert_called_once_with(mock_fetch_job_messages_by_id.return_value) - @pytest.mark.parametrize( - "soft_fail, expected_exception", - ((False, AirflowException), (True, AirflowSkipException)), - ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): + 
def test_poke_raise_exception(self, mock_hook): mock_get_job = mock_hook.return_value.get_job mock_fetch_job_messages_by_id = mock_hook.return_value.fetch_job_messages_by_id callback = mock.MagicMock() @@ -431,12 +402,11 @@ def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} with pytest.raises( - expected_exception, + AirflowException, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_DONE}", ): @@ -534,14 +504,7 @@ def test_execute_complete_success_with_callback_function(self): ) assert actual_result == expected_result - @pytest.mark.parametrize( - "expected_exception, soft_fail", - ( - (AirflowException, False), - (AirflowSkipException, True), - ), - ) - def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + def test_execute_complete_not_success_status_raises_exception(self): """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" task = DataflowJobMessagesSensor( task_id=TEST_TASK_ID, @@ -553,9 +516,8 @@ def test_execute_complete_not_success_status_raises_exception(self, expected_exc gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete( context=None, event={"status": "error", "message": "test error message", "result": None} ) @@ -601,15 +563,8 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): ) callback.assert_called_once_with(mock_fetch_job_autoscaling_events_by_id.return_value) - @pytest.mark.parametrize( - "soft_fail, expected_exception", - ( - (False, AirflowException), - (True, AirflowSkipException), - ), - ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception_on_terminal_state(self, mock_hook, soft_fail, expected_exception): + def test_poke_raise_exception_on_terminal_state(self, mock_hook): mock_get_job = mock_hook.return_value.get_job mock_fetch_job_autoscaling_events_by_id = mock_hook.return_value.fetch_job_autoscaling_events_by_id callback = mock.MagicMock() @@ -623,12 +578,11 @@ def test_poke_raise_exception_on_terminal_state(self, mock_hook, soft_fail, expe project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} with pytest.raises( - expected_exception, + AirflowException, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_DONE}", ): @@ -724,14 +678,7 @@ def test_execute_complete_success_with_callback_function(self): ) assert actual_result == expected_result - @pytest.mark.parametrize( - "expected_exception, soft_fail", - ( - (AirflowException, False), - (AirflowSkipException, True), - ), - ) - def test_execute_complete_not_success_status_raises_exception(self, expected_exception, soft_fail): + def test_execute_complete_not_success_status_raises_exception(self): """Tests that AirflowException or AirflowSkipException is raised if the trigger event contains an error.""" task = DataflowJobAutoScalingEventsSensor( 
task_id=TEST_TASK_ID, @@ -743,9 +690,8 @@ def test_execute_complete_not_success_status_raises_exception(self, expected_exc gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, deferrable=True, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): task.execute_complete( context=None, event={"status": "error", "message": "test error message", "result": None}, diff --git a/tests/providers/google/cloud/sensors/test_datafusion.py b/tests/providers/google/cloud/sensors/test_datafusion.py index 5f230931366a..03223903ce62 100644 --- a/tests/providers/google/cloud/sensors/test_datafusion.py +++ b/tests/providers/google/cloud/sensors/test_datafusion.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.google.cloud.hooks.datafusion import PipelineStates from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor @@ -74,11 +74,8 @@ def test_poke(self, mock_hook, expected_status, current_status, sensor_return): instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook") - def test_assertion(self, mock_hook, soft_fail, expected_exception): + def test_assertion(self, mock_hook): mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL} task = CloudDataFusionPipelineStateSensor( @@ -92,21 +89,17 @@ def test_assertion(self, mock_hook, soft_fail, expected_exception): location=LOCATION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) mock_hook.return_value.get_pipeline_workflow.return_value = {"status": "FAILED"} with pytest.raises( - expected_exception, + AirflowException, match=f"Pipeline with id '{PIPELINE_ID}' state is: FAILED. 
Terminating sensor...", ): task.poke(mock.MagicMock()) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook") - def test_not_found_exception(self, mock_hook, soft_fail, expected_exception): + def test_not_found_exception(self, mock_hook): mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL} mock_hook.return_value.get_pipeline_workflow.side_effect = AirflowNotFoundException() @@ -121,11 +114,10 @@ def test_not_found_exception(self, mock_hook, soft_fail, expected_exception): location=LOCATION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) with pytest.raises( - expected_exception, + AirflowException, match="Specified Pipeline ID was not found.", ): task.poke(mock.MagicMock()) diff --git a/tests/providers/google/cloud/sensors/test_dataplex.py b/tests/providers/google/cloud/sensors/test_dataplex.py index 4d5e72da71f2..ebd99efda17d 100644 --- a/tests/providers/google/cloud/sensors/test_dataplex.py +++ b/tests/providers/google/cloud/sensors/test_dataplex.py @@ -22,7 +22,7 @@ from google.api_core.gapic_v1.method import DEFAULT from google.cloud.dataplex_v1.types import DataScanJob -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanResultTimeoutException from airflow.providers.google.cloud.sensors.dataplex import ( DataplexDataProfileJobStatusSensor, @@ -82,11 +82,8 @@ def test_done(self, mock_hook): assert result - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPLEX_HOOK) - def test_deleting(self, mock_hook, soft_fail, expected_exception): + def test_deleting(self, mock_hook): task = self.create_task(TaskState.DELETING) mock_hook.return_value.get_task.return_value = task @@ -99,10 +96,9 @@ def test_deleting(self, mock_hook, soft_fail, expected_exception): api_version=API_VERSION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Task is going to be deleted"): + with pytest.raises(AirflowException, match="Task is going to be deleted"): sensor.poke(context={}) mock_hook.return_value.get_task.assert_called_once_with( diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py index f26eae5cf1c0..669a9a09f2a9 100644 --- a/tests/providers/google/cloud/sensors/test_dataproc.py +++ b/tests/providers/google/cloud/sensors/test_dataproc.py @@ -23,7 +23,7 @@ from google.api_core.exceptions import ServerError from google.cloud.dataproc_v1.types import Batch, JobStatus -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor, DataprocJobSensor from airflow.version import version as airflow_version @@ -66,11 +66,8 @@ def test_done(self, mock_hook): ) assert ret - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_error(self, mock_hook, soft_fail, expected_exception): + def test_error(self, mock_hook): job = 
self.create_job(JobStatus.State.ERROR) job_id = "job_id" mock_hook.return_value.get_job.return_value = job @@ -82,10 +79,9 @@ def test_error(self, mock_hook, soft_fail, expected_exception): dataproc_job_id=job_id, gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Job failed"): + with pytest.raises(AirflowException, match="Job failed"): sensor.poke(context={}) mock_hook.return_value.get_job.assert_called_once_with( @@ -113,11 +109,8 @@ def test_wait(self, mock_hook): ) assert not ret - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_cancelled(self, mock_hook, soft_fail, expected_exception): + def test_cancelled(self, mock_hook): job = self.create_job(JobStatus.State.CANCELLED) job_id = "job_id" mock_hook.return_value.get_job.return_value = job @@ -129,9 +122,8 @@ def test_cancelled(self, mock_hook, soft_fail, expected_exception): dataproc_job_id=job_id, gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Job was cancelled"): + with pytest.raises(AirflowException, match="Job was cancelled"): sensor.poke(context={}) mock_hook.return_value.get_job.assert_called_once_with( @@ -170,11 +162,8 @@ def test_wait_timeout(self, mock_hook): result = sensor.poke(context={}) assert not result - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_wait_timeout_raise_exception(self, mock_hook, soft_fail, expected_exception): + def test_wait_timeout_raise_exception(self, mock_hook): job_id = "job_id" mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready") @@ -186,13 +175,12 @@ def test_wait_timeout_raise_exception(self, mock_hook, soft_fail, expected_excep gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, wait_timeout=300, - soft_fail=soft_fail, ) sensor._duration = Mock() sensor._duration.return_value = 301 - with pytest.raises(expected_exception, match="Timeout: dataproc job job_id is not ready after 300s"): + with pytest.raises(AirflowException, match="Timeout: dataproc job job_id is not ready after 300s"): sensor.poke(context={}) @@ -223,11 +211,11 @@ def test_succeeded(self, mock_hook): ) assert ret - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_cancelled(self, mock_hook, soft_fail, expected_exception): + def test_cancelled( + self, + mock_hook, + ): batch = self.create_batch(Batch.State.CANCELLED) mock_hook.return_value.get_batch.return_value = batch @@ -238,20 +226,19 @@ def test_cancelled(self, mock_hook, soft_fail, expected_exception): batch_id="batch_id", gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Batch was cancelled."): + with pytest.raises(AirflowException, match="Batch was cancelled."): sensor.poke(context={}) mock_hook.return_value.get_batch.assert_called_once_with( batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_error(self, mock_hook, soft_fail, expected_exception): + def test_error( 
+ self, + mock_hook, + ): batch = self.create_batch(Batch.State.FAILED) mock_hook.return_value.get_batch.return_value = batch @@ -262,10 +249,9 @@ def test_error(self, mock_hook, soft_fail, expected_exception): batch_id="batch_id", gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match="Batch failed"): + with pytest.raises(AirflowException, match="Batch failed"): sensor.poke(context={}) mock_hook.return_value.get_batch.assert_called_once_with( diff --git a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py index cdc6dbde6c74..92ae82d32569 100644 --- a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py +++ b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py @@ -22,7 +22,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.sensors.dataproc_metastore import MetastoreHivePartitionSensor DATAPROC_METASTORE_SENSOR_PATH = "airflow.providers.google.cloud.sensors.dataproc_metastore.{}" @@ -107,15 +107,10 @@ def test_poke_positive_manifest( ) assert sensor.poke(context={}) == expected_result - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @pytest.mark.parametrize("empty_manifest", [dict(), list(), tuple(), None, ""]) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook")) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs")) - def test_poke_empty_manifest( - self, mock_parse_json_from_gcs, mock_hook, empty_manifest, soft_fail, expected_exception - ): + def test_poke_empty_manifest(self, mock_parse_json_from_gcs, mock_hook, empty_manifest): mock_parse_json_from_gcs.return_value = empty_manifest sensor = MetastoreHivePartitionSensor( @@ -125,18 +120,14 @@ def test_poke_empty_manifest( table=TEST_TABLE, partitions=[PARTITION_1], gcp_conn_id=GCP_CONN_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception): + with pytest.raises(AirflowException): sensor.poke(context={}) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook")) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs")) - def test_poke_wrong_status(self, mock_parse_json_from_gcs, mock_hook, soft_fail, expected_exception): + def test_poke_wrong_status(self, mock_parse_json_from_gcs, mock_hook): error_message = "Test error message" mock_parse_json_from_gcs.return_value = {"code": 1, "message": error_message} @@ -147,8 +138,7 @@ def test_poke_wrong_status(self, mock_parse_json_from_gcs, mock_hook, soft_fail, table=TEST_TABLE, partitions=[PARTITION_1], gcp_conn_id=GCP_CONN_ID, - soft_fail=soft_fail, ) - with pytest.raises(expected_exception, match=f"Request failed: {error_message}"): + with pytest.raises(AirflowException, match=f"Request failed: {error_message}"): sensor.poke(context={}) diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py index 68c8fc8f3bb5..f738d6bbba41 100644 --- a/tests/providers/google/cloud/sensors/test_gcs.py +++ b/tests/providers/google/cloud/sensors/test_gcs.py @@ -26,8 +26,6 @@ from airflow.exceptions import ( AirflowProviderDeprecationWarning, - AirflowSensorTimeout, - AirflowSkipException, 
     TaskDeferred,
 )
 from airflow.models.dag import DAG, AirflowException
@@ -168,10 +166,7 @@ def test_gcs_object_existence_sensor_deferred(self, mock_hook):
             task.execute({})
         assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_gcs_object_existence_sensor_deferred_execute_failure(self, soft_fail, expected_exception):
+    def test_gcs_object_existence_sensor_deferred_execute_failure(self):
         """Tests that an AirflowException is raised in case of error event when deferrable is set to True"""
         task = GCSObjectExistenceSensor(
             task_id="task-id",
@@ -179,9 +174,8 @@ def test_gcs_object_existence_sensor_deferred_execute_failure(self, soft_fail, e
             object=TEST_OBJECT,
             google_cloud_conn_id=TEST_GCP_CONN_ID,
             deferrable=True,
-            soft_fail=soft_fail,
         )
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
     def test_gcs_object_existence_sensor_execute_complete(self):
@@ -238,10 +232,7 @@ def test_gcs_object_existence_async_sensor(self, mock_hook):
             task.execute({})
         assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_gcs_object_existence_async_sensor_execute_failure(self, soft_fail, expected_exception):
+    def test_gcs_object_existence_async_sensor_execute_failure(self):
         """Tests that an AirflowException is raised in case of error event"""
         with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message):
             task = GCSObjectExistenceAsyncSensor(
@@ -249,9 +240,8 @@ def test_gcs_object_existence_async_sensor_execute_failure(self, soft_fail, expe
                 bucket=TEST_BUCKET,
                 object=TEST_OBJECT,
                 google_cloud_conn_id=TEST_GCP_CONN_ID,
-                soft_fail=soft_fail,
             )
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
     def test_gcs_object_existence_async_sensor_execute_complete(self):
@@ -348,13 +338,9 @@ def test_gcs_object_update_async_sensor(self, mock_hook):
             exc.value.trigger, GCSCheckBlobUpdateTimeTrigger
         ), "Trigger is not a GCSCheckBlobUpdateTimeTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_gcs_object_update_async_sensor_execute_failure(self, soft_fail, expected_exception):
+    def test_gcs_object_update_async_sensor_execute_failure(self):
         """Tests that an AirflowException is raised in case of error event"""
-        self.OPERATOR.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.OPERATOR.execute_complete(
                 context={}, event={"status": "error", "message": "test failure message"}
             )
@@ -426,21 +412,17 @@ def test_execute(self, mock_hook):
         mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)
         assert response == generated_messages
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowSensorTimeout), (True, AirflowSkipException))
-    )
     @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
-    def test_execute_timeout(self, mock_hook, soft_fail, expected_exception):
+    def test_execute_timeout(self, mock_hook):
         task = GCSObjectsWithPrefixExistenceSensor(
             task_id="task-id",
             bucket=TEST_BUCKET,
             prefix=TEST_PREFIX,
             poke_interval=0,
             timeout=1,
-            soft_fail=soft_fail,
         )
         mock_hook.return_value.list.return_value = []
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             task.execute(mock.MagicMock)
     @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
@@ -493,15 +475,11 @@ def test_gcs_object_with_prefix_existence_async_sensor(self, mock_hook):
             self.OPERATOR.execute(mock.MagicMock())
         assert isinstance(exc.value.trigger, GCSPrefixBlobTrigger), "Trigger is not a GCSPrefixBlobTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     def test_gcs_object_with_prefix_existence_async_sensor_execute_failure(
-        self, soft_fail, expected_exception
+        self,
     ):
         """Tests that an AirflowException is raised in case of error event"""
-        self.OPERATOR.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.OPERATOR.execute_complete(
                 context={}, event={"status": "error", "message": "test failure message"}
             )
@@ -549,14 +527,10 @@ def test_get_gcs_hook(self, mock_hook):
         )
         assert mock_hook.return_value == self.sensor.hook
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time", mock_time)
-    def test_files_deleted_between_pokes_throw_error(self, soft_fail, expected_exception):
-        self.sensor.soft_fail = soft_fail
+    def test_files_deleted_between_pokes_throw_error(self):
         self.sensor.is_bucket_updated({"a", "b"})
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.sensor.is_bucket_updated({"a"})
     @mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time", mock_time)
@@ -641,14 +615,10 @@ def test_gcs_upload_session_complete_async_sensor(self, mock_hook):
             exc.value.trigger, GCSUploadSessionTrigger
         ), "Trigger is not a GCSUploadSessionTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_gcs_upload_session_complete_sensor_execute_failure(self, soft_fail, expected_exception):
+    def test_gcs_upload_session_complete_sensor_execute_failure(self):
         """Tests that an AirflowException is raised in case of error event"""
-        self.OPERATOR.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.OPERATOR.execute_complete(
                 context={}, event={"status": "error", "message": "test failure message"}
             )
diff --git a/tests/providers/google/cloud/sensors/test_looker.py b/tests/providers/google/cloud/sensors/test_looker.py
index 8e352340552a..732d70bf7518 100644
--- a/tests/providers/google/cloud/sensors/test_looker.py
+++ b/tests/providers/google/cloud/sensors/test_looker.py
@@ -20,7 +20,7 @@
 import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.looker import JobStatus
 from airflow.providers.google.cloud.sensors.looker import LookerCheckPdtBuildSensor
@@ -51,11 +51,8 @@ def test_done(self, mock_hook):
         # assert we got a response
         assert ret
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @mock.patch(SENSOR_PATH.format("LookerHook"))
-    def test_error(self, mock_hook, soft_fail, expected_exception):
+    def test_error(self, mock_hook):
         mock_hook.return_value.pdt_build_status.return_value = {
             "status": JobStatus.ERROR.value,
             "message": "test",
@@ -66,10 +63,9 @@ def test_error(self, mock_hook, soft_fail, expected_exception):
             task_id=TASK_ID,
             looker_conn_id=LOOKER_CONN_ID,
             materialization_id=TEST_JOB_ID,
-            soft_fail=soft_fail,
         )
-        with pytest.raises(expected_exception, match="PDT materialization job failed"):
+        with pytest.raises(AirflowException, match="PDT materialization job failed"):
             sensor.poke(context={})
         # assert hook.pdt_build_status called once
@@ -93,11 +89,8 @@ def test_wait(self, mock_hook):
         # assert we got NO response
         assert not ret
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @mock.patch(SENSOR_PATH.format("LookerHook"))
-    def test_cancelled(self, mock_hook, soft_fail, expected_exception):
+    def test_cancelled(self, mock_hook):
         mock_hook.return_value.pdt_build_status.return_value = {"status": JobStatus.CANCELLED.value}
         # run task in mock context
@@ -105,23 +98,19 @@ def test_cancelled(self, mock_hook, soft_fail, expected_exception):
             task_id=TASK_ID,
             looker_conn_id=LOOKER_CONN_ID,
             materialization_id=TEST_JOB_ID,
-            soft_fail=soft_fail,
         )
-        with pytest.raises(expected_exception, match="PDT materialization job was cancelled"):
+        with pytest.raises(AirflowException, match="PDT materialization job was cancelled"):
             sensor.poke(context={})
         # assert hook.pdt_build_status called once
         mock_hook.return_value.pdt_build_status.assert_called_once_with(materialization_id=TEST_JOB_ID)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_empty_materialization_id(self, soft_fail, expected_exception):
+    def test_empty_materialization_id(self):
         # run task in mock context
         sensor = LookerCheckPdtBuildSensor(
-            task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID, materialization_id="", soft_fail=soft_fail
+            task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID, materialization_id=""
         )
-        with pytest.raises(expected_exception, match="^Invalid `materialization_id`.$"):
+        with pytest.raises(AirflowException, match="^Invalid `materialization_id`.$"):
             sensor.poke(context={})
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py
index 1fb63b82a6d2..a77167dda303 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -23,7 +23,7 @@
 import pytest
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor
 from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
@@ -98,23 +98,19 @@ def test_execute(self, mock_hook):
         )
         assert generated_dicts == response
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowSensorTimeout), (True, AirflowSkipException))
-    )
     @mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook")
-    def test_execute_timeout(self, mock_hook, soft_fail, expected_exception):
+    def test_execute_timeout(self, mock_hook):
         operator = PubSubPullSensor(
             task_id=TASK_ID,
             project_id=TEST_PROJECT,
             subscription=TEST_SUBSCRIPTION,
             poke_interval=0,
             timeout=1,
-            soft_fail=soft_fail,
         )
         mock_hook.return_value.pull.return_value = []
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             operator.execute({})
     @mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook")
@@ -171,10 +167,7 @@ def test_pubsub_pull_sensor_async(self):
             task.execute(context={})
         assert isinstance(exc.value.trigger, PubsubPullTrigger), "Trigger is not a PubsubPullTrigger"
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_pubsub_pull_sensor_async_execute_should_throw_exception(self, soft_fail, expected_exception):
+    def test_pubsub_pull_sensor_async_execute_should_throw_exception(self):
         """Tests that an AirflowException is raised in case of error event"""
         operator = PubSubPullSensor(
@@ -183,10 +176,9 @@ def test_pubsub_pull_sensor_async_execute_should_throw_exception(self, soft_fail
             project_id=TEST_PROJECT,
             subscription=TEST_SUBSCRIPTION,
             deferrable=True,
-            soft_fail=soft_fail,
         )
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             operator.execute_complete(
                 context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
             )
diff --git a/tests/providers/google/cloud/sensors/test_workflows.py b/tests/providers/google/cloud/sensors/test_workflows.py
index 232c1db0e0fd..12d66ac62dde 100644
--- a/tests/providers/google/cloud/sensors/test_workflows.py
+++ b/tests/providers/google/cloud/sensors/test_workflows.py
@@ -21,7 +21,7 @@
 import pytest
 from google.cloud.workflows.executions_v1beta import Execution
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor
 BASE_PATH = "airflow.providers.google.cloud.sensors.workflows.{}"
@@ -90,11 +90,8 @@ def test_poke_wait(self, mock_hook):
         assert result is False
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @mock.patch(BASE_PATH.format("WorkflowsHook"))
-    def test_poke_failure(self, mock_hook, soft_fail, expected_exception):
+    def test_poke_failure(self, mock_hook):
         mock_hook.return_value.get_execution.return_value = mock.MagicMock(state=Execution.State.FAILED)
         op = WorkflowExecutionSensor(
             task_id="test_task",
@@ -107,7 +104,6 @@ def test_poke_failure(self, mock_hook, soft_fail, expected_exception):
             metadata=METADATA,
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
-            soft_fail=soft_fail,
         )
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             op.poke({})
diff --git a/tests/providers/google/marketing_platform/sensors/test_display_video.py b/tests/providers/google/marketing_platform/sensors/test_display_video.py
index 883b8c34d77b..68621e43bf2a 100644
--- a/tests/providers/google/marketing_platform/sensors/test_display_video.py
+++ b/tests/providers/google/marketing_platform/sensors/test_display_video.py
@@ -21,7 +21,7 @@
 import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.google.marketing_platform.sensors.display_video import (
     GoogleDisplayVideo360GetSDFDownloadOperationSensor,
     GoogleDisplayVideo360RunQuerySensor,
 )
@@ -71,22 +71,20 @@ def test_poke(self, mock_base_op, hook_mock):
             operation_name=operation_name
         )
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @mock.patch(f"{MODULE_NAME}.GoogleDisplayVideo360Hook")
     @mock.patch(f"{MODULE_NAME}.BaseSensorOperator")
     def test_poke_with_exception(
-        self, mock_base_op, hook_mock, soft_fail: bool, expected_exception: type[AirflowException]
+        self,
+        mock_base_op,
+        hook_mock,
     ):
         operation_name = "operation_name"
         op = GoogleDisplayVideo360GetSDFDownloadOperationSensor(
             operation_name=operation_name,
             api_version=API_VERSION,
             task_id="test_task",
-            soft_fail=soft_fail,
         )
         hook_mock.return_value.get_sdf_download_operation.return_value = {"error": "error"}
-        with pytest.raises(expected_exception, match="The operation finished in error with error"):
+        with pytest.raises(AirflowException, match="The operation finished in error with error"):
             op.poke(context={})
diff --git a/tests/providers/jenkins/sensors/test_jenkins.py b/tests/providers/jenkins/sensors/test_jenkins.py
index df6a7360df95..df3e0bef6671 100644
--- a/tests/providers/jenkins/sensors/test_jenkins.py
+++ b/tests/providers/jenkins/sensors/test_jenkins.py
@@ -21,7 +21,7 @@
 import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.jenkins.hooks.jenkins import JenkinsHook
 from airflow.providers.jenkins.sensors.jenkins import JenkinsBuildSensor
@@ -69,9 +69,6 @@ def test_poke_buliding(self, mock_jenkins, build_number, build_state, result):
         assert jenkins_mock.get_job_info.call_count == 0 if build_number else 1
         jenkins_mock.get_build_info.assert_called_once_with("a_job_on_jenkins", target_build_number)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @pytest.mark.parametrize(
         "build_number, build_state, result",
         [
         ],
     )
     @patch("jenkins.Jenkins")
-    def test_poke_finish_building(
-        self, mock_jenkins, build_number, build_state, result, soft_fail, expected_exception
-    ):
+    def test_poke_finish_building(self, mock_jenkins, build_number, build_state, result):
         target_build_number = build_number or 10
         jenkins_mock = MagicMock()
@@ -108,10 +103,9 @@ def test_poke_finish_building(
             job_name="a_job_on_jenkins",
             build_number=target_build_number,
             target_states=["SUCCESS"],
-            soft_fail=soft_fail,
         )
         if result not in sensor.target_states:
-            with pytest.raises(expected_exception):
+            with pytest.raises(AirflowException):
                 sensor.poke(None)
             assert jenkins_mock.get_build_info.call_count == 2
         else:
diff --git a/tests/providers/microsoft/azure/sensors/test_data_factory.py b/tests/providers/microsoft/azure/sensors/test_data_factory.py
index 78631f036891..6e4288ef9de3 100644
--- a/tests/providers/microsoft/azure/sensors/test_data_factory.py
+++ b/tests/providers/microsoft/azure/sensors/test_data_factory.py
@@ -21,7 +21,7 @@
 import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryHook,
     AzureDataFactoryPipelineRunException,
@@ -111,14 +111,10 @@ def test_adf_pipeline_status_sensor_execute_complete_success(self):
             self.defered_sensor.execute_complete(context={}, event={"status": "success", "message": msg})
             mock_log_info.assert_called_with(msg)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_adf_pipeline_status_sensor_execute_complete_failure(self, soft_fail, expected_exception):
+    def test_adf_pipeline_status_sensor_execute_complete_failure(self):
         """Assert execute_complete method fail"""
-        self.defered_sensor.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.defered_sensor.execute_complete(context={}, event={"status": "error", "message": ""})
@@ -150,12 +146,10 @@ def test_adf_pipeline_status_sensor_execute_complete_success(self):
             self.SENSOR.execute_complete(context={}, event={"status": "success", "message": msg})
             mock_log_info.assert_called_with(msg)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_adf_pipeline_status_sensor_execute_complete_failure(self, soft_fail, expected_exception):
+    def test_adf_pipeline_status_sensor_execute_complete_failure(
+        self,
+    ):
         """Assert execute_complete method fail"""
-        self.SENSOR.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py
index 63ffa45165c4..46a386cb44d6 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -24,7 +24,7 @@
 import pendulum
 import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
@@ -160,14 +160,10 @@ def test_wasb_blob_sensor_execute_complete_success(self, event):
             self.SENSOR.execute_complete(context={}, event=event)
             mock_log_info.assert_called_with(event["message"])
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_wasb_blob_sensor_execute_complete_failure(self, soft_fail, expected_exception):
+    def test_wasb_blob_sensor_execute_complete_failure(self):
         """Assert execute_complete method raises an exception when the triggerer fires an error event."""
-        self.SENSOR.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})
@@ -289,12 +285,8 @@ def test_wasb_prefix_sensor_execute_complete_success(self, event):
             self.SENSOR.execute_complete(context={}, event=event)
             mock_log_info.assert_called_with(event["message"])
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_wasb_prefix_sensor_execute_complete_failure(self, soft_fail, expected_exception):
+    def test_wasb_prefix_sensor_execute_complete_failure(self):
         """Assert execute_complete method raises an exception when the triggerer fires an error event."""
-        self.SENSOR.soft_fail = soft_fail
         with pytest.raises(AirflowException):
             self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})
diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py
index 25add45e153f..4d1be081af16 100644
--- a/tests/providers/sftp/sensors/test_sftp.py
+++ b/tests/providers/sftp/sensors/test_sftp.py
@@ -25,7 +25,7 @@
 from paramiko.sftp import SFTP_FAILURE, SFTP_NO_SUCH_FILE
 from pendulum import datetime as pendulum_datetime, timezone
-from airflow.exceptions import AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.sftp.sensors.sftp import SFTPSensor
 from airflow.sensors.base import PokeReturnValue
@@ -52,17 +52,13 @@ def test_file_absent(self, sftp_hook_mock):
         sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
         assert not output
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, OSError), (True, AirflowSkipException))
-    )
     @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
-    def test_sftp_failure(self, sftp_hook_mock, soft_fail: bool, expected_exception):
+    def test_sftp_failure(self, sftp_hook_mock):
         sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(SFTP_FAILURE, "SFTP failure")
-        sftp_sensor = SFTPSensor(
-            task_id="unit_test", path="/path/to/file/1970-01-01.txt", soft_fail=soft_fail
-        )
+
+        sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/1970-01-01.txt")
         context = {"ds": "1970-01-01"}
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             sftp_sensor.poke(context)
     def test_hook_not_created_during_init(self):
diff --git a/tests/providers/tableau/sensors/test_tableau.py b/tests/providers/tableau/sensors/test_tableau.py
index 181bc5d6bfa5..7b11aaa00a96 100644
--- a/tests/providers/tableau/sensors/test_tableau.py
+++ b/tests/providers/tableau/sensors/test_tableau.py
@@ -20,9 +20,8 @@
 import pytest
-from airflow.exceptions import AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.tableau.sensors.tableau import (
-    TableauJobFailedException,
     TableauJobFinishCode,
     TableauJobStatusSensor,
 )
@@ -50,9 +49,6 @@ def test_poke(self, mock_tableau_hook):
         assert job_finished
         mock_tableau_hook.get_job_status.assert_called_once_with(job_id=sensor.job_id)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, TableauJobFailedException), (True, AirflowSkipException))
-    )
     @pytest.mark.parametrize(
         "finish_code",
         [
             TableauJobFinishCode.ERROR,
             TableauJobFinishCode.CANCELED,
         ],
     )
     @patch("airflow.providers.tableau.sensors.tableau.TableauHook")
-    def test_poke_failed(self, mock_tableau_hook, finish_code, soft_fail: bool, expected_exception):
+    def test_poke_failed(self, mock_tableau_hook, finish_code):
         """
         Test poke failed
         """
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         mock_tableau_hook.get_job_status.return_value = finish_code
         sensor = TableauJobStatusSensor(**self.kwargs)
-        sensor.soft_fail = soft_fail
-        with pytest.raises(expected_exception):
+        with pytest.raises(AirflowException):
             sensor.poke({})
         mock_tableau_hook.get_job_status.assert_called_once_with(job_id=sensor.job_id)