From 1f13f261447b9c5239b86d706d7c4f715a644395 Mon Sep 17 00:00:00 2001 From: yangyulely Date: Mon, 28 Oct 2024 18:43:16 +0800 Subject: [PATCH] Remove returns in final clause of athena hooks (#43426) --- .../providers/amazon/aws/hooks/athena.py | 40 ++++++++++++------- .../tests/amazon/aws/hooks/test_athena.py | 21 ++++++++++ 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/athena.py b/providers/src/airflow/providers/amazon/aws/hooks/athena.py index 4969f339dba5..60405a863999 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/athena.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/athena.py @@ -155,14 +155,15 @@ def check_query_status(self, query_execution_id: str, use_cache: bool = False) - state = None try: state = response["QueryExecution"]["Status"]["State"] - except Exception: - self.log.exception( - "Exception while getting query state. Query execution id: %s", query_execution_id - ) - finally: + except Exception as e: # The error is being absorbed here and is being handled by the caller. # The error is being absorbed to implement retries. - return state + self.log.exception( + "Exception while getting query state. Query execution id: %s, Exception: %s", + query_execution_id, + e, + ) + return state def get_state_change_reason(self, query_execution_id: str, use_cache: bool = False) -> str | None: """ @@ -177,15 +178,15 @@ def get_state_change_reason(self, query_execution_id: str, use_cache: bool = Fal reason = None try: reason = response["QueryExecution"]["Status"]["StateChangeReason"] - except Exception: + except Exception as e: + # The error is being absorbed here and is being handled by the caller. + # The error is being absorbed to implement retries. self.log.exception( - "Exception while getting query state change reason. Query execution id: %s", + "Exception while getting query state change reason. Query execution id: %s, Exception: %s", query_execution_id, + e, ) - finally: - # The error is being absorbed here and is being handled by the caller. - # The error is being absorbed to implement retries. - return reason + return reason def get_query_results( self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000 @@ -287,9 +288,18 @@ def poll_query_status( ) except AirflowException as error: # this function does not raise errors to keep previous behavior. - self.log.warning(error) - finally: - return self.check_query_status(query_execution_id) + self.log.warning( + "AirflowException while polling query status. Query execution id: %s, Exception: %s", + query_execution_id, + error, + ) + except Exception as e: + self.log.warning( + "Unexpected exception while polling query status. Query execution id: %s, Exception: %s", + query_execution_id, + e, + ) + return self.check_query_status(query_execution_id) def get_output_location(self, query_execution_id: str) -> str: """ diff --git a/providers/tests/amazon/aws/hooks/test_athena.py b/providers/tests/amazon/aws/hooks/test_athena.py index 3262bb473a0d..5e75cf3d09b9 100644 --- a/providers/tests/amazon/aws/hooks/test_athena.py +++ b/providers/tests/amazon/aws/hooks/test_athena.py @@ -201,6 +201,15 @@ def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.assert_called_once() assert result == "RUNNING" + @mock.patch.object(AthenaHook, "get_conn") + def test_hook_poll_query_with_exception(self, mock_conn): + mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT + result = self.athena.poll_query_status( + query_execution_id=MOCK_DATA["query_execution_id"], max_polling_attempts=1, sleep_time=0 + ) + mock_conn.return_value.get_query_execution.assert_called_once() + assert not result + @mock.patch.object(AthenaHook, "get_conn") def test_hook_get_output_location(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT @@ -230,6 +239,18 @@ def test_hook_get_output_location_invalid_response(self, caplog): self.athena.get_output_location(query_execution_id="PLACEHOLDER") assert "Error retrieving OutputLocation" in caplog.text + @mock.patch.object(AthenaHook, "get_query_info") + def test_check_query_status_normal(self, mock_get_query_info): + mock_get_query_info.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION + state = self.athena.check_query_status(query_execution_id=MOCK_DATA["query_execution_id"]) + assert state == "SUCCEEDED" + + @mock.patch.object(AthenaHook, "get_query_info") + def test_check_query_status_exception(self, mock_get_query_info): + mock_get_query_info.return_value = MOCK_QUERY_EXECUTION_OUTPUT + state = self.athena.check_query_status(query_execution_id=MOCK_DATA["query_execution_id"]) + assert not state + @mock.patch.object(AthenaHook, "get_conn") def test_hook_get_query_info_caching(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT