Skip to content

Commit

Permalink
Remove returns in final clause of athena hooks (#43426)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangyulely authored Oct 28, 2024
1 parent 9772dbe commit 1f13f26
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
40 changes: 25 additions & 15 deletions providers/src/airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,15 @@ def check_query_status(self, query_execution_id: str, use_cache: bool = False) -
state = None
try:
state = response["QueryExecution"]["Status"]["State"]
except Exception:
self.log.exception(
"Exception while getting query state. Query execution id: %s", query_execution_id
)
finally:
except Exception as e:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return state
self.log.exception(
"Exception while getting query state. Query execution id: %s, Exception: %s",
query_execution_id,
e,
)
return state

def get_state_change_reason(self, query_execution_id: str, use_cache: bool = False) -> str | None:
"""
Expand All @@ -177,15 +178,15 @@ def get_state_change_reason(self, query_execution_id: str, use_cache: bool = Fal
reason = None
try:
reason = response["QueryExecution"]["Status"]["StateChangeReason"]
except Exception:
except Exception as e:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
self.log.exception(
"Exception while getting query state change reason. Query execution id: %s",
"Exception while getting query state change reason. Query execution id: %s, Exception: %s",
query_execution_id,
e,
)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return reason
return reason

def get_query_results(
self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
Expand Down Expand Up @@ -287,9 +288,18 @@ def poll_query_status(
)
except AirflowException as error:
# this function does not raise errors to keep previous behavior.
self.log.warning(error)
finally:
return self.check_query_status(query_execution_id)
self.log.warning(
"AirflowException while polling query status. Query execution id: %s, Exception: %s",
query_execution_id,
error,
)
except Exception as e:
self.log.warning(
"Unexpected exception while polling query status. Query execution id: %s, Exception: %s",
query_execution_id,
e,
)
return self.check_query_status(query_execution_id)

def get_output_location(self, query_execution_id: str) -> str:
"""
Expand Down
21 changes: 21 additions & 0 deletions providers/tests/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == "RUNNING"

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_poll_query_with_exception(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.poll_query_status(
query_execution_id=MOCK_DATA["query_execution_id"], max_polling_attempts=1, sleep_time=0
)
mock_conn.return_value.get_query_execution.assert_called_once()
assert not result

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_output_location(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
Expand Down Expand Up @@ -230,6 +239,18 @@ def test_hook_get_output_location_invalid_response(self, caplog):
self.athena.get_output_location(query_execution_id="PLACEHOLDER")
assert "Error retrieving OutputLocation" in caplog.text

@mock.patch.object(AthenaHook, "get_query_info")
def test_check_query_status_normal(self, mock_get_query_info):
mock_get_query_info.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
state = self.athena.check_query_status(query_execution_id=MOCK_DATA["query_execution_id"])
assert state == "SUCCEEDED"

@mock.patch.object(AthenaHook, "get_query_info")
def test_check_query_status_exception(self, mock_get_query_info):
mock_get_query_info.return_value = MOCK_QUERY_EXECUTION_OUTPUT
state = self.athena.check_query_status(query_execution_id=MOCK_DATA["query_execution_id"])
assert not state

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_query_info_caching(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
Expand Down

0 comments on commit 1f13f26

Please sign in to comment.