diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index bc3051911d6b..922089d76e54 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -436,7 +436,7 @@
       "aiohttp>=3.9.2, <4",
       "apache-airflow-providers-common-sql>=1.20.0",
       "apache-airflow>=2.8.0",
-      "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+      "databricks-sql-connector>=3.0.0",
       "mergedeep>=1.3.4",
       "pandas>=1.5.3,<2.2;python_version<\"3.9\"",
       "pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
diff --git a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
index 6d4f679b2eed..05ef274230d9 100644
--- a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -27,7 +27,6 @@
     Any,
     Callable,
     Iterable,
-    List,
     Mapping,
     Sequence,
     TypeVar,
@@ -36,18 +35,20 @@
 )
 
 from databricks import sql  # type: ignore[attr-defined]
+from databricks.sql.types import Row
 
 from airflow.exceptions import (
     AirflowException,
     AirflowProviderDeprecationWarning,
 )
+from airflow.models.connection import Connection as AirflowConnection
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
 from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
-    from databricks.sql.types import Row
+
 
 
 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
@@ -106,7 +107,7 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(databricks_conn_id, caller=caller)
-        self._sql_conn = None
+        self._sql_conn: Connection | None = None
         self._token: str | None = None
         self._http_path = http_path
         self._sql_endpoint_name = sql_endpoint_name
@@ -146,7 +147,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
         else:
             return endpoint
 
-    def get_conn(self) -> Connection:
+    def get_conn(self) -> AirflowConnection:
         """Return a Databricks SQL connection object."""
         if not self._http_path:
             if self._sql_endpoint_name:
@@ -161,20 +162,15 @@ def get_conn(self) -> Connection:
                 "or sql_endpoint_name should be specified"
             )
 
-        requires_init = True
-        if not self._token:
-            self._token = self._get_token(raise_error=True)
-        else:
-            new_token = self._get_token(raise_error=True)
-            if new_token != self._token:
-                self._token = new_token
-            else:
-                requires_init = False
+        prev_token = self._token
+        new_token = self._get_token(raise_error=True)
+        if not self._token or new_token != self._token:
+            self._token = new_token
 
         if not self.session_config:
             self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")
 
-        if not self._sql_conn or requires_init:
+        if not self._sql_conn or prev_token != new_token:
             if self._sql_conn:  # close already existing connection
                 self._sql_conn.close()
             self._sql_conn = sql.connect(
@@ -189,7 +185,10 @@ def get_conn(self) -> Connection:
             **self._get_extra_config(),
             **self.additional_params,
         )
-        return self._sql_conn
+
+        if self._sql_conn is None:
+            raise AirflowException("SQL connection is not initialized")
+        return cast(AirflowConnection, self._sql_conn)
 
     @overload  # type: ignore[override]
     def run(
@@ -310,22 +309,23 @@ def run(
         else:
             return results
 
-    def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
+    def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]:
         """Transform the databricks Row objects into namedtuple."""
         # Below ignored lines respect namedtuple docstring, but mypy do not support dynamically
         # instantiated namedtuple, and will never do: https://github.com/python/mypy/issues/848
         if isinstance(result, list):
-            rows: list[Row] = result
+            rows: Sequence[Row] = result
             if not rows:
                 return []
             rows_fields = tuple(rows[0].__fields__)
             rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore
-            return cast(List[tuple], [rows_object(*row) for row in rows])
-        else:
-            row: Row = result
-            row_fields = tuple(row.__fields__)
+            return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
+        elif isinstance(result, Row):
+            row_fields = tuple(result.__fields__)
             row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore
-            return cast(tuple, row_object(*row))
+            return cast(tuple[Any, ...], row_object(*result))
+
+        raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")
 
     def bulk_dump(self, table, tmp_file):
         raise NotImplementedError()
diff --git a/providers/src/airflow/providers/databricks/provider.yaml b/providers/src/airflow/providers/databricks/provider.yaml
index 690b15c0ef9a..ef9d8a5dcc62 100644
--- a/providers/src/airflow/providers/databricks/provider.yaml
+++ b/providers/src/airflow/providers/databricks/provider.yaml
@@ -75,10 +75,7 @@ dependencies:
   - apache-airflow>=2.8.0
   - apache-airflow-providers-common-sql>=1.20.0
   - requests>=2.27.0,<3
-  # The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
-  # it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
-  # The 2.9.1 (to be released soon) already contains the fix
-  - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
+  - databricks-sql-connector>=3.0.0
   - aiohttp>=3.9.2, <4
   - mergedeep>=1.3.4
   - pandas>=2.1.2,<2.2;python_version>="3.9"
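A note on the `get_conn` hunk above: the old code threaded a `requires_init` flag through nested branches, while the new code derives the same reconnect decision from a single token comparison (assigning the token unconditionally is equivalent, since re-assigning an equal value is a no-op). A minimal sketch of the reuse-or-reconnect pattern, with hypothetical `fetch_token`/`open_connection` callables standing in for `self._get_token(raise_error=True)` and `sql.connect(...)`:

```python
from __future__ import annotations

from typing import Any, Callable


class CachedConnector:
    """Reuse one open connection until the auth token rotates (sketch)."""

    def __init__(self, fetch_token: Callable[[], str], open_connection: Callable[[str], Any]) -> None:
        self._fetch_token = fetch_token          # stands in for self._get_token(raise_error=True)
        self._open_connection = open_connection  # stands in for sql.connect(...)
        self._token: str | None = None
        self._conn: Any = None

    def get_conn(self) -> Any:
        prev_token = self._token
        new_token = self._fetch_token()
        self._token = new_token
        # Reconnect only when no connection exists yet or the token changed;
        # otherwise the cached connection is returned untouched.
        if self._conn is None or prev_token != new_token:
            if self._conn is not None:
                self._conn.close()
            self._conn = self._open_connection(new_token)
        return self._conn
```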
diff --git a/providers/tests/snowflake/operators/test_snowflake_sql.py b/providers/tests/snowflake/operators/test_snowflake_sql.py
index fb1bcd172635..e6c6dac28165 100644
--- a/providers/tests/snowflake/operators/test_snowflake_sql.py
+++ b/providers/tests/snowflake/operators/test_snowflake_sql.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Any, Callable
 from unittest import mock
 from unittest.mock import MagicMock, patch
 
@@ -27,15 +28,19 @@
 databricks = importorskip("databricks")
 
+MockRow = None
 try:
     from databricks.sql.types import Row
 except ImportError:
     # Row is used in the parametrize so it's parsed during collection and we need to have a viable
     # replacement for the collection time when databricks is not installed (Python 3.12 for now)
-    def Row(*args, **kwargs):
+    def MockRow(*args: Any, **kwargs: Any) -> MagicMock:
         return MagicMock()
 
+RowType: type[Row] | Callable[..., MagicMock] = Row if "Row" in locals() else MockRow
+
+
 from airflow.models.connection import Connection
 from airflow.providers.common.compat.openlineage.facet import (
     ColumnLineageDatasetFacet,
@@ -59,45 +64,45 @@ def Row(*args, **kwargs):
             "select * from dummy",
             True,
             True,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Single SQL statement, return_last, split statement",
         ),
         pytest.param(
             "select * from dummy;select * from dummy2",
             True,
             True,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Multiple SQL statements, return_last, split statement",
         ),
         pytest.param(
             "select * from dummy",
             False,
             False,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
         ),
         pytest.param(
             "select * from dummy",
             True,
             False,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
         ),
         pytest.param(
             ["select * from dummy"],
             False,
             False,
-            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            [[RowType(id=1, value="value1"), RowType(id=2, value="value2")]],
             [[("id",), ("value",)]],
-            [([Row(id=1, value="value1"), Row(id=2, value="value2")])],
+            [([RowType(id=1, value="value1"), RowType(id=2, value="value2")])],
             id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
         ),
         pytest.param(
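Context for the `MockRow`/`RowType` shuffle above: `Row` is called inside `pytest.param(...)` arguments, which pytest evaluates at collection time even when `importorskip` later skips the module, so a callable stand-in must exist whenever databricks is absent. A sketch of the same guard written with `try`/`except`/`else` (which avoids the `"Row" in locals()` check), using a hypothetical optional dependency `somelib` in place of `databricks`:

```python
from __future__ import annotations

from typing import Any, Callable
from unittest.mock import MagicMock

try:
    from somelib.types import Row  # hypothetical optional dependency
except ImportError:
    # Parametrize arguments are built at collection time, so a callable
    # stand-in must exist even when the real library is not installed.
    def _mock_row(*args: Any, **kwargs: Any) -> MagicMock:
        return MagicMock()

    RowType: Callable[..., Any] = _mock_row
else:
    RowType = Row
```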
@@ -105,13 +110,13 @@ def Row(*args, **kwargs):
             False,
             False,
             [
-                [Row(id=1, value="value1"), Row(id=2, value="value2")],
-                [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+                [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
+                [RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")],
             ],
             [[("id",), ("value",)], [("id2",), ("value2",)]],
             [
-                ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
-                ([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
+                ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
+                ([RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")]),
             ],
             id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
         ),
@@ -120,13 +125,13 @@ def Row(*args, **kwargs):
             True,
             False,
             [
-                [Row(id=1, value="value1"), Row(id=2, value="value2")],
-                [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+                [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
+                [RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")],
             ],
             [[("id",), ("value",)], [("id2",), ("value2",)]],
             [
-                ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
-                ([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
+                ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
+                ([RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")]),
             ],
             id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
         ),
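For reference, the `_make_common_data_structure` hunk earlier keeps the connector's conversion strategy intact: read `__fields__` from the first `Row`, build one `namedtuple` class with `rename=True` (invalid or duplicate field names are silently renamed), then rebuild every row as an instance of that class. A self-contained sketch of the list branch, with a stand-in `FakeRow` (a plain tuple carrying `__fields__`) in place of `databricks.sql.types.Row` so it runs without the connector installed:

```python
from collections import namedtuple
from typing import Any, Sequence


class FakeRow(tuple):
    """Stand-in for databricks.sql.types.Row: a tuple that exposes __fields__."""

    __fields__ = ["id", "value"]

    def __new__(cls, *values: Any) -> "FakeRow":
        return super().__new__(cls, values)


def to_namedtuples(rows: Sequence[FakeRow]) -> list[tuple[Any, ...]]:
    # Mirrors the list branch: one namedtuple class built from the first
    # row's fields, then every row rebuilt as an instance of it.
    if not rows:
        return []
    row_cls = namedtuple("Row", tuple(rows[0].__fields__), rename=True)  # type: ignore[misc]
    return [row_cls(*row) for row in rows]


print(to_namedtuples([FakeRow(1, "a"), FakeRow(2, "b")]))
# [Row(id=1, value='a'), Row(id=2, value='b')]
```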