From aa2a937e5c51d430cedcc1dfcbf516f0bb8115c1 Mon Sep 17 00:00:00 2001
From: Lorin Dawson <22798188+R7L208@users.noreply.github.com>
Date: Tue, 12 Nov 2024 20:50:55 -0700
Subject: [PATCH] Allow Databricks SQL hook to cancel timed out queries
 (#42668)

---
 .../providers/databricks/exceptions.py        |  32 +++
 .../databricks/hooks/databricks_sql.py        |  41 +++-
 .../databricks/operators/databricks_sql.py    |   5 +
 .../databricks/hooks/test_databricks_sql.py   | 231 ++++++++++++++----
 providers/tests/databricks/test_exceptions.py |  33 +++
 providers/tests/edge/executors/__init__.py    |   1 -
 6 files changed, 287 insertions(+), 56 deletions(-)
 create mode 100644 providers/src/airflow/providers/databricks/exceptions.py
 create mode 100644 providers/tests/databricks/test_exceptions.py

diff --git a/providers/src/airflow/providers/databricks/exceptions.py b/providers/src/airflow/providers/databricks/exceptions.py
new file mode 100644
index 000000000000..0488c975ed15
--- /dev/null
+++ b/providers/src/airflow/providers/databricks/exceptions.py
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# Note: Any AirflowException raised is expected to cause the TaskInstance
+# to be marked in an ERROR state
+"""Exceptions used by the Databricks provider."""
+
+from __future__ import annotations
+
+from airflow.exceptions import AirflowException
+
+
+class DatabricksSqlExecutionError(AirflowException):
+    """Raised when there is an error in SQL execution."""
+
+
+class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError):
+    """Raised when a SQL execution times out."""
diff --git a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
index fed85835d095..6d4f679b2eed 100644
--- a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,10 +16,12 @@
 # under the License.
 from __future__ import annotations
 
+import threading
 import warnings
 from collections import namedtuple
 from contextlib import closing
 from copy import copy
+from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -35,8 +37,12 @@
 
 from databricks import sql  # type: ignore[attr-defined]
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import (
+    AirflowException,
+    AirflowProviderDeprecationWarning,
+)
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
+from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 
 if TYPE_CHECKING:
@@ -49,6 +55,18 @@
 T = TypeVar("T")
 
 
+def create_timeout_thread(cur, execution_timeout: timedelta | None) -> threading.Timer | None:
+    """Create and start a timer that cancels the query on ``cur``'s connection once ``execution_timeout`` elapses."""
+    if execution_timeout is not None:
+        seconds_to_timeout = execution_timeout.total_seconds()
+        t = threading.Timer(seconds_to_timeout, cur.connection.cancel)
+        t.start()
+    else:
+        t = None
+
+    return t
+
+
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
     """
     Hook to interact with Databricks SQL.
@@ -184,6 +200,7 @@ def run(
         handler: None = ...,
         split_statements: bool = ...,
         return_last: bool = ...,
+        execution_timeout: timedelta | None = None,
     ) -> None: ...
 
     @overload
@@ -195,6 +212,7 @@ def run(
         handler: Callable[[Any], T] = ...,
         split_statements: bool = ...,
         return_last: bool = ...,
+        execution_timeout: timedelta | None = None,
    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
 
     def run(
@@ -205,6 +223,7 @@ def run(
         handler: Callable[[Any], T] | None = None,
         split_statements: bool = True,
         return_last: bool = True,
+        execution_timeout: timedelta | None = None,
     ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         """
         Run a command or a list of commands.
@@ -224,6 +243,8 @@
         :param return_last: Whether to return result for only last statement or for all after split
+        :param execution_timeout: max time allowed for each SQL statement to run; if exceeded, the
+            query is cancelled and a ``DatabricksSqlExecutionTimeout`` is raised.
         :return: return only result of the LAST SQL expression if handler was provided unless return_last
             is set to False.
         """
         self.descriptions = []
         if isinstance(sql, str):
@@ -248,7 +269,23 @@
                 self.set_autocommit(conn, autocommit)
 
                 with closing(conn.cursor()) as cur:
-                    self._run_command(cur, sql_statement, parameters)  # type: ignore[attr-defined]
+                    t = create_timeout_thread(cur, execution_timeout)
+
+                    # TODO: adjust this to make testing easier
+                    try:
+                        self._run_command(cur, sql_statement, parameters)  # type: ignore[attr-defined]
+                    except Exception as e:
+                        # A timer that is still alive never fired, so this is a genuine execution
+                        # error; a dead timer already cancelled the statement, so it is a timeout.
+                        if t is None or t.is_alive():
+                            raise DatabricksSqlExecutionError(
+                                f"Error running SQL statement: {sql_statement}. {str(e)}"
+                            )
+                        raise DatabricksSqlExecutionTimeout(
+                            f"Timeout threshold exceeded for SQL statement: {sql_statement}. The query was cancelled."
+                        )
+                    finally:
+                        if t is not None:
+                            t.cancel()
+
                     if handler is not None:
                         raw_result = handler(cur)
                         if self.return_tuple:
diff --git a/providers/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/src/airflow/providers/databricks/operators/databricks_sql.py
index 7e59fc2a9d59..0998bc9a6ef5 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -353,3 +353,8 @@ def execute(self, context: Context) -> Any:
         self.log.info("Executing: %s", sql)
         hook = self._get_hook()
         hook.run(sql)
+
+    def on_kill(self) -> None:
+        # NB: on_kill isn't required for this operator, since query cancellation on
+        # timeout is handled inside `DatabricksSqlHook.run()`, which `execute()` calls.
+        ...
diff --git a/providers/tests/databricks/hooks/test_databricks_sql.py b/providers/tests/databricks/hooks/test_databricks_sql.py
index fc1582db5d90..3eb628d46c2b 100644
--- a/providers/tests/databricks/hooks/test_databricks_sql.py
+++ b/providers/tests/databricks/hooks/test_databricks_sql.py
@@ -18,17 +18,19 @@
 #
 from __future__ import annotations
 
+import threading
 from collections import namedtuple
+from datetime import timedelta
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import PropertyMock, patch
 
 import pytest
 from databricks.sql.types import Row
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
+from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook, create_timeout_thread
 from airflow.utils.session import provide_session
 
 pytestmark = pytest.mark.db_test
@@ -56,6 +58,53 @@ def databricks_hook():
     return DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=True)
 
 
+@pytest.fixture
+def mock_get_conn():
+    # Start the patcher
+    mock_patch = patch("airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook.get_conn")
+    mock_conn = mock_patch.start()
+    # Use yield to provide the mock object
+    yield mock_conn
+    # Stop the patcher
+    mock_patch.stop()
+
+
+@pytest.fixture
+def mock_get_requests():
+    # Start the patcher
+    mock_patch = patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    mock_requests = mock_patch.start()
+
+    # Configure the mock object
+    mock_requests.codes.ok = 200
+    mock_requests.get.return_value.json.return_value = {
+        "endpoints": [
+            {
+                "id": "1264e5078741679a",
+                "name": "Test",
+                "odbc_params": {
+                    "hostname": "xx.cloud.databricks.com",
+                    "path": "/sql/1.0/endpoints/1264e5078741679a",
+                },
+            }
+        ]
+    }
+    status_code_mock = PropertyMock(return_value=200)
+    type(mock_requests.get.return_value).status_code = status_code_mock
+
+    # Yield the mock object
+    yield mock_requests
+
+    # Stop the patcher after the test
+    mock_patch.stop()
+
+
+@pytest.fixture
+def mock_timer():
+    with patch("threading.Timer") as mock_timer:
+        yield mock_timer
+
+
 def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
     return [(field,) for field in fields]
 
 
@@ -65,13 +114,14 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
 
 
 @pytest.mark.parametrize(
-    "return_last, split_statements, sql, cursor_calls, return_tuple,"
+    "return_last, split_statements, sql, execution_timeout, cursor_calls, return_tuple,"
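+    # Every case below supplies None for the new execution_timeout slot; the timeout
+    # behavior itself is exercised separately by test_execution_timeout_exceeded.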
"cursor_descriptions, cursor_results, hook_descriptions, hook_results, ", [ pytest.param( True, False, "select * from test.test", + None, ["select * from test.test"], False, [["id", "value"]], @@ -84,6 +134,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: False, False, "select * from test.test;", + None, ["select * from test.test"], False, [["id", "value"]], @@ -96,6 +147,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, True, "select * from test.test;", + None, ["select * from test.test"], False, [["id", "value"]], @@ -108,6 +160,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: False, True, "select * from test.test;", + None, ["select * from test.test"], False, [["id", "value"]], @@ -120,6 +173,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, True, "select * from test.test;select * from test.test2;", + None, ["select * from test.test", "select * from test.test2"], False, [["id", "value"], ["id2", "value2"]], @@ -132,6 +186,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: False, True, "select * from test.test;select * from test.test2;", + None, ["select * from test.test", "select * from test.test2"], False, [["id", "value"], ["id2", "value2"]], @@ -147,6 +202,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, True, ["select * from test.test;"], + None, ["select * from test.test"], False, [["id", "value"]], @@ -159,6 +215,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: False, True, ["select * from test.test;"], + None, ["select * from test.test"], False, [["id", "value"]], @@ -171,6 +228,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, True, "select * from test.test;select * from test.test2;", + None, ["select * from test.test", "select * from test.test2"], False, [["id", "value"], ["id2", "value2"]], @@ -183,6 +241,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: False, True, "select * from test.test;select * from test.test2;", + None, ["select * from test.test", "select * from test.test2"], False, [["id", "value"], ["id2", "value2"]], @@ -198,6 +257,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, False, "select * from test.test", + None, ["select * from test.test"], True, [["id", "value"]], @@ -210,6 +270,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, False, "select * from test.test", + None, ["select * from test.test"], True, [["id", "value"]], @@ -222,6 +283,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: True, False, "select * from test.test", + None, ["select * from test.test"], True, [["id", "value"]], @@ -236,67 +298,51 @@ def test_query( return_last, split_statements, sql, + execution_timeout, cursor_calls, return_tuple, cursor_descriptions, cursor_results, hook_descriptions, hook_results, + mock_get_conn, + mock_get_requests, ): - with ( - patch("airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook.get_conn") as mock_conn, - patch("airflow.providers.databricks.hooks.databricks_base.requests") as mock_requests, - ): - mock_requests.codes.ok = 200 - mock_requests.get.return_value.json.return_value = { - "endpoints": [ - { - "id": "1264e5078741679a", - "name": "Test", - "odbc_params": { - "hostname": "xx.cloud.databricks.com", - "path": "/sql/1.0/endpoints/1264e5078741679a", - }, - } - ] - } - status_code_mock = mock.PropertyMock(return_value=200) - 
type(mock_requests.get.return_value).status_code = status_code_mock
-        connections = []
-        cursors = []
-        for index in range(len(cursor_descriptions)):
-            conn = mock.MagicMock()
-            cur = mock.MagicMock(
-                rowcount=len(cursor_results[index]),
-                description=get_cursor_descriptions(cursor_descriptions[index]),
-            )
-            cur.fetchall.return_value = cursor_results[index]
-            conn.cursor.return_value = cur
-            cursors.append(cur)
-            connections.append(conn)
-        mock_conn.side_effect = connections
-
-        if not return_tuple:
-            with pytest.warns(
-                AirflowProviderDeprecationWarning,
-                match="""Returning a raw `databricks.sql.Row` object is deprecated. A namedtuple will be
+    connections = []
+    cursors = []
+    for index, cursor_description in enumerate(cursor_descriptions):
+        conn = mock.MagicMock()
+        cur = mock.MagicMock(
+            rowcount=len(cursor_results[index]),
+            description=get_cursor_descriptions(cursor_description),
+        )
+        cur.fetchall.return_value = cursor_results[index]
+        conn.cursor.return_value = cur
+        cursors.append(cur)
+        connections.append(conn)
+    mock_get_conn.side_effect = connections
+
+    if not return_tuple:
+        with pytest.warns(
+            AirflowProviderDeprecationWarning,
+            match="""Returning a raw `databricks.sql.Row` object is deprecated. A namedtuple will be
                 returned instead in a future release of the databricks provider. Set `return_tuple=True` to
                 enable this behavior.""",
-            ):
-                databricks_hook = DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=return_tuple)
-        else:
-            databricks_hook = DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=return_tuple)
-        results = databricks_hook.run(
-            sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
-        )
+        ):
+            databricks_hook = DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=return_tuple)
+    else:
+        databricks_hook = DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=return_tuple)
+    results = databricks_hook.run(
+        sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
+    )
 
-        assert databricks_hook.descriptions == hook_descriptions
-        assert databricks_hook.last_description == hook_descriptions[-1]
-        assert results == hook_results
+    assert databricks_hook.descriptions == hook_descriptions
+    assert databricks_hook.last_description == hook_descriptions[-1]
+    assert results == hook_results
 
-        for index, cur in enumerate(cursors):
-            cur.execute.assert_has_calls([mock.call(cursor_calls[index])])
-            cur.close.assert_called()
+    for index, cur in enumerate(cursors):
+        cur.execute.assert_has_calls([mock.call(cursor_calls[index])])
+        cur.close.assert_called()
 
 
 @pytest.mark.parametrize(
@@ -330,3 +376,82 @@ def test_incorrect_column_names(row_objects, fields_names):
     """
     result = DatabricksSqlHook(return_tuple=True)._make_common_data_structure(row_objects)
     assert result._fields == fields_names
+
+
+def test_execution_timeout_exceeded(
+    mock_get_conn,
+    mock_get_requests,
+    sql="select * from test.test",
+    execution_timeout=timedelta(microseconds=0),
+    cursor_descriptions=(
+        "id",
+        "value",
+    ),
+    cursor_results=(
+        Row(id=1, value=2),
+        Row(id=11, value=12),
+    ),
+):
+    with patch(
+        "airflow.providers.databricks.hooks.databricks_sql.create_timeout_thread"
+    ) as mock_create_timeout_thread, patch.object(DatabricksSqlHook, "_run_command") as mock_run_command:
+        conn = mock.MagicMock()
+        cur = mock.MagicMock(
+            rowcount=len(cursor_results),
+            description=get_cursor_descriptions(cursor_descriptions),
+        )
+
+        # Simulate a timeout: return the timer the hook would create, but never start it.
+        mock_create_timeout_thread.return_value = threading.Timer(
+            execution_timeout.total_seconds(), cur.connection.cancel
+        )
+
+        mock_run_command.side_effect = Exception("Mocked exception")
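+        # An un-started Timer reports is_alive() == False, so run() classifies the mocked
+        # failure below as a timeout and raises DatabricksSqlExecutionTimeout (an AirflowException).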
exception") + + cur.fetchall.return_value = cursor_results + conn.cursor.return_value = cur + mock_get_conn.side_effect = [conn] + + with pytest.raises(AirflowException) as exc_info: + DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=True).run( + sql=sql, + execution_timeout=execution_timeout, + handler=fetch_all_handler, + ) + + assert "Timeout threshold exceeded" in str(exc_info.value) + + +def test_create_timeout_thread( + mock_get_conn, + mock_get_requests, + mock_timer, + cursor_descriptions=( + "id", + "value", + ), +): + cur = mock.MagicMock( + rowcount=1, + description=get_cursor_descriptions(cursor_descriptions), + ) + timeout = timedelta(seconds=1) + thread = create_timeout_thread(cur=cur, execution_timeout=timeout) + mock_timer.assert_called_once_with(timeout.total_seconds(), cur.connection.cancel) + assert thread is not None + + +def test_create_timeout_thread_no_timeout( + mock_get_conn, + mock_get_requests, + mock_timer, + cursor_descriptions=( + "id", + "value", + ), +): + cur = mock.MagicMock( + rowcount=1, + description=get_cursor_descriptions(cursor_descriptions), + ) + thread = create_timeout_thread(cur=cur, execution_timeout=None) + mock_timer.assert_not_called() + assert thread is None diff --git a/providers/tests/databricks/test_exceptions.py b/providers/tests/databricks/test_exceptions.py new file mode 100644 index 000000000000..17c42af47469 --- /dev/null +++ b/providers/tests/databricks/test_exceptions.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout + + +def test_databricks_sql_execution_error(): + """Test if AirflowTaskExecutionError can be raised correctly.""" + with pytest.raises(DatabricksSqlExecutionError, match="Task execution failed"): + raise DatabricksSqlExecutionError("Task execution failed") + + +def test_databricks_sql_execution_timeout(): + """Test if AirflowTaskExecutionTimeout can be raised correctly.""" + with pytest.raises(DatabricksSqlExecutionTimeout, match="Task execution timed out"): + raise DatabricksSqlExecutionTimeout("Task execution timed out") diff --git a/providers/tests/edge/executors/__init__.py b/providers/tests/edge/executors/__init__.py index 217e5db96078..13a83393a912 100644 --- a/providers/tests/edge/executors/__init__.py +++ b/providers/tests/edge/executors/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information