Upgrade databricks provider dependency #43272

Open · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -436,7 +436,7 @@
       "aiohttp>=3.9.2, <4",
       "apache-airflow-providers-common-sql>=1.20.0",
       "apache-airflow>=2.8.0",
-      "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+      "databricks-sql-connector>=3.0.0",
       "mergedeep>=1.3.4",
       "pandas>=1.5.3,<2.2;python_version<\"3.9\"",
       "pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
44 changes: 22 additions & 22 deletions providers/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -27,7 +27,6 @@
     Any,
     Callable,
     Iterable,
-    List,
     Mapping,
     Sequence,
     TypeVar,
@@ -36,18 +35,20 @@
 )

 from databricks import sql  # type: ignore[attr-defined]
-from databricks.sql.types import Row

 from airflow.exceptions import (
     AirflowException,
     AirflowProviderDeprecationWarning,
 )
+from airflow.models.connection import Connection as AirflowConnection
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
 from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

 if TYPE_CHECKING:
     from databricks.sql.client import Connection
+    from databricks.sql.types import Row
+

 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
@@ -106,7 +107,7 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(databricks_conn_id, caller=caller)
-        self._sql_conn = None
+        self._sql_conn: Connection | None = None
         self._token: str | None = None
         self._http_path = http_path
         self._sql_endpoint_name = sql_endpoint_name
@@ -146,7 +147,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
         else:
             return endpoint

-    def get_conn(self) -> Connection:
+    def get_conn(self) -> AirflowConnection:
         """Return a Databricks SQL connection object."""
         if not self._http_path:
             if self._sql_endpoint_name:
@@ -161,20 +162,15 @@ def get_conn(self) -> Connection:
                 "or sql_endpoint_name should be specified"
             )

-        requires_init = True
-        if not self._token:
-            self._token = self._get_token(raise_error=True)
-        else:
-            new_token = self._get_token(raise_error=True)
-            if new_token != self._token:
-                self._token = new_token
-            else:
-                requires_init = False
+        prev_token = self._token
+        new_token = self._get_token(raise_error=True)
+        if not self._token or new_token != self._token:
+            self._token = new_token

         if not self.session_config:
             self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")

-        if not self._sql_conn or requires_init:
+        if not self._sql_conn or prev_token != new_token:
             if self._sql_conn:  # close already existing connection
                 self._sql_conn.close()
             self._sql_conn = sql.connect(
@@ -189,7 +185,10 @@ def get_conn(self) -> Connection:
                 **self._get_extra_config(),
                 **self.additional_params,
             )
-        return self._sql_conn
+
+        if self._sql_conn is None:
+            raise AirflowException("SQL connection is not initialized")
+        return cast(AirflowConnection, self._sql_conn)

     @overload  # type: ignore[override]
     def run(
@@ -310,22 +309,23 @@ def run(
         else:
             return results

-    def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
+    def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]:
         """Transform the databricks Row objects into namedtuple."""
         # Below ignored lines respect namedtuple docstring, but mypy do not support dynamically
         # instantiated namedtuple, and will never do: https://github.com/python/mypy/issues/848
         if isinstance(result, list):
-            rows: list[Row] = result
+            rows: Sequence[Row] = result
             if not rows:
                 return []
             rows_fields = tuple(rows[0].__fields__)
             rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore
-            return cast(List[tuple], [rows_object(*row) for row in rows])
-        else:
-            row: Row = result
-            row_fields = tuple(row.__fields__)
+            return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
+        elif isinstance(result, Row):
+            row_fields = tuple(result.__fields__)
             row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore
-            return cast(tuple, row_object(*row))
+            return cast(tuple[Any, ...], row_object(*result))
+
+        raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")

     def bulk_dump(self, table, tmp_file):
         raise NotImplementedError()
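For reviewers who want to see the conversion in isolation, the sketch below (an illustration, not part of the diff) mimics the list branch of `_make_common_data_structure`. `FakeRow` is a hypothetical stand-in for `databricks.sql.types.Row`; like the real class (per the diff above), it exposes `__fields__` and unpacks positionally.

```python
from collections import namedtuple

# Hypothetical stand-in for databricks.sql.types.Row: exposes __fields__ and
# iterates over its values, which is all the hook's conversion relies on.
FakeRow = namedtuple("FakeRow", ["id", "value"])
FakeRow.__fields__ = ("id", "value")


def to_common_structure(rows):
    """Mirror the hook's list branch: rebuild Row-like objects as plain namedtuples."""
    if not rows:
        return []
    fields = tuple(rows[0].__fields__)
    row_cls = namedtuple("Row", fields, rename=True)
    return [row_cls(*row) for row in rows]


print(to_common_structure([FakeRow(1, "a"), FakeRow(2, "b")]))
# [Row(id=1, value='a'), Row(id=2, value='b')]
```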
5 changes: 1 addition & 4 deletions providers/src/airflow/providers/databricks/provider.yaml
@@ -75,10 +75,7 @@ dependencies:
   - apache-airflow>=2.8.0
   - apache-airflow-providers-common-sql>=1.20.0
   - requests>=2.27.0,<3
-  # The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
-  # it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
-  # The 2.9.1 (to be released soon) already contains the fix
-  - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
+  - databricks-sql-connector>=3.0.0
   - aiohttp>=3.9.2, <4
   - mergedeep>=1.3.4
   - pandas>=2.1.2,<2.2;python_version>="3.9"
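A quick way to confirm that the new floor resolves in a local environment (illustrative, stdlib only):

```python
# Print the installed connector version; after this PR the resolver should
# land on a 3.x release rather than the old 2.x series.
from importlib.metadata import version

print(version("databricks-sql-connector"))
```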
43 changes: 24 additions & 19 deletions providers/tests/snowflake/operators/test_snowflake_sql.py
Contributor:
How are changes to snowflake tests related to this PR?

Author:
The Snowflake test module uses the Row class from databricks.sql.types[0], which triggered mypy checks on the module when upgrading the databricks-sql-connector package. I just patched the errors mypy was flagging, since I think refactoring the module to not use the Row class is probably its own change, but I'm open to suggestions if this isn't the way we'd like to go.

[0] https://github.com/apache/airflow/pull/43272/files/175059dfe5cfb1f64297a8e26bcfe278edd7ebfb#diff-6fa19a213d772197807bb12413c4731e6b8a001a241f7f9009e152a1aac1f066R33

Contributor:
Wait. What??
I am not near my laptop so I can't investigate right now, but this sounds very odd. We need to look into the commit that added it to see the reason for creating such odd coupling.

Author:
I think [0] is the original PR that introduced the import, and [1] later added logic to mock it. It looks to be testing the return value of a SQLExecuteQueryOperator, but test_exec_success doesn't seem to be a Snowflake-specific test?

[0] #28006
[1] #38074

Contributor:
I don't think this is right. Snowflake tests should not use the databricks SDK.

Member:
Yeah, I think this is a mistake: someone used the databricks Row in the snowflake tests. They "match" because both represent the output of a DB-API cursor, which has a generally similar structure, but this should be fixed to either use some generic Row structure or create our own Row class in the tests.

Member:
> I think refactoring the module to not use the Row class is probably its own change, but I'm open to suggestions if this isn't the way we'd like to go.

Yes, that's the way to go. Ideally we would add tests similar to the current snowflake ones to the databricks tests, using the databricks Row if possible, and change the snowflake ones to not use the databricks Row.
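One possible shape for that, a hedged sketch rather than anything this PR implements: a test-local Row factory built on namedtuple, which keeps the keyword construction and value-equality semantics the parametrized cases depend on.

```python
from collections import namedtuple


def Row(**kwargs):
    """Hypothetical test-local stand-in for databricks.sql.types.Row.

    Instances are namedtuples, so they compare equal by value and unpack
    positionally, e.g. Row(id=1, value="value1") == Row(id=1, value="value1").
    """
    return namedtuple("Row", list(kwargs))(**kwargs)
```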

@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+from typing import Any, Callable
 from unittest import mock
 from unittest.mock import MagicMock, patch

@@ -27,15 +28,19 @@

 databricks = importorskip("databricks")

+MockRow = None
 try:
     from databricks.sql.types import Row
 except ImportError:
     # Row is used in the parametrize so it's parsed during collection and we need to have a viable
     # replacement for the collection time when databricks is not installed (Python 3.12 for now)
-    def Row(*args, **kwargs):
+    def MockRow(*args: Any, **kwargs: Any) -> MagicMock:
         return MagicMock()


+RowType: type[Row] | Callable[..., MagicMock] = Row if "Row" in locals() else MockRow
+
+
 from airflow.models.connection import Connection
 from airflow.providers.common.compat.openlineage.facet import (
     ColumnLineageDatasetFacet,
@@ -59,59 +64,59 @@ def Row(*args, **kwargs):
             "select * from dummy",
             True,
             True,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Single SQL statement, return_last, split statement",
         ),
         pytest.param(
             "select * from dummy;select * from dummy2",
             True,
             True,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Multiple SQL statements, return_last, split statement",
         ),
         pytest.param(
             "select * from dummy",
             False,
             False,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
             id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
         ),
         pytest.param(
             "select * from dummy",
             True,
             False,
-            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
             [[("id",), ("value",)]],
-            ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
            id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
         ),
         pytest.param(
             ["select * from dummy"],
             False,
             False,
-            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            [[RowType(id=1, value="value1"), RowType(id=2, value="value2")]],
             [[("id",), ("value",)]],
-            [([Row(id=1, value="value1"), Row(id=2, value="value2")])],
+            [([RowType(id=1, value="value1"), RowType(id=2, value="value2")])],
             id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
         ),
         pytest.param(
             ["select * from dummy", "select * from dummy2"],
             False,
             False,
             [
-                [Row(id=1, value="value1"), Row(id=2, value="value2")],
-                [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+                [RowType(id=1, value="value1"), RowType(id=2, value="value2")],
+                [RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")],
             ],
             [[("id",), ("value",)], [("id2",), ("value2",)]],
             [
-                ([Row(id=1, value="value1"), Row(id=2, value="value2")]),
-                ([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
+                ([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
+                ([RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")]),
             ],
             id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
         ),
Expand All @@ -120,13 +125,13 @@ def Row(*args, **kwargs):
True,
False,
[
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[RowType(id=1, value="value1"), RowType(id=2, value="value2")],
[RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")],
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
([RowType(id=1, value="value1"), RowType(id=2, value="value2")]),
([RowType(id2=1, value2="value1"), RowType(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
),