From f94e6edd570a587c78fa0f2bf45a15cd24fe7c5f Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:00:48 -0800 Subject: [PATCH 01/21] Trying lock approach for dependency management (#878) --- CHANGELOG.md | 6 + pyproject.toml | 7 ++ uv.lock | 306 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 uv.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index a4c107a9..98bdcc16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt-databricks 1.9.2 (TBD) + +### Under the Hood + +- Switch to UV and locks for dependency management ([878](https://github.com/databricks/dbt-databricks/pull/878)) + ## dbt-databricks 1.9.1 (December 16, 2024) ### Features diff --git a/pyproject.toml b/pyproject.toml index d2f728d8..a42a8f4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,14 @@ check-sdist = [ "pip freeze | grep dbt-databricks", ] +[tool.hatch.env] +requires = ["hatch-pip-compile"] + [tool.hatch.envs.default] +type = "pip-compile" +pip-compile-resolver = "uv" +lock-filename = "uv.lock" +pip-compile-constraint = "default" dependencies = [ "dbt_common @ git+https://github.com/dbt-labs/dbt-common.git", "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git@main", diff --git a/uv.lock b/uv.lock new file mode 100644 index 00000000..a5783157 --- /dev/null +++ b/uv.lock @@ -0,0 +1,306 @@ +# +# This file is autogenerated by hatch-pip-compile with Python 3.9 +# +# - dbt_common@ git+https://github.com/dbt-labs/dbt-common.git +# - dbt-adapters@ git+https://github.com/dbt-labs/dbt-adapters.git@main +# - dbt-core@ git+https://github.com/dbt-labs/dbt-core.git@main#subdirectory=core +# - dbt-tests-adapter@ git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter +# - dbt-spark@ git+https://github.com/dbt-labs/dbt-spark.git@main +# - pytest +# - pytest-xdist +# - pytest-dotenv +# - freezegun +# - mypy +# - pre-commit +# - ruff +# - types-requests +# - debugpy +# - pydantic<2,>=1.10.0 +# - databricks-sdk==0.17.0 +# - databricks-sql-connector<4.0.0,>=3.5.0 +# - dbt-adapters<2.0,>=1.7.0 +# - dbt-common<2.0,>=1.10.0 +# - dbt-core<2.0,>=1.8.7 +# - dbt-spark<2.0,>=1.8.0 +# - keyring>=23.13.0 +# - pydantic>=1.10.0 +# + +agate==1.9.1 + # via + # dbt-adapters + # dbt-common + # dbt-core +attrs==24.3.0 + # via + # jsonschema + # referencing +babel==2.16.0 + # via agate +backports-tarfile==1.2.0 + # via jaraco-context +cachetools==5.5.0 + # via google-auth +certifi==2024.12.14 + # via requests +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.0 + # via requests +click==8.1.7 + # via + # dbt-core + # dbt-semantic-interfaces +colorama==0.4.6 + # via dbt-common +daff==1.3.46 + # via dbt-core +databricks-sdk==0.17.0 + # via hatch.envs.default +databricks-sql-connector==3.6.0 + # via hatch.envs.default +dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git@e3964d76c1719baf5e3fe689d385aec1d8535d15 + # via + # hatch.envs.default + # dbt-core + # dbt-spark + # dbt-tests-adapter +dbt-common @ git+https://github.com/dbt-labs/dbt-common.git@c72ea7e3abf70ce632d30722036dd0b4afcaf330 + # via + # hatch.envs.default + # dbt-adapters + # dbt-core + # dbt-spark +dbt-core @ git+https://github.com/dbt-labs/dbt-core.git@6c61cb7f7adbdce8edec35a887d6c766a401e403#subdirectory=core + # via + # hatch.envs.default + # dbt-spark + # dbt-tests-adapter +dbt-extractor==0.5.1 + # via dbt-core +dbt-semantic-interfaces==0.8.3 + # via dbt-core +dbt-spark @ 
git+https://github.com/dbt-labs/dbt-spark.git@a38a288d7d3868c88313350f7d369223b0f03a05 + # via hatch.envs.default +dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git@e3964d76c1719baf5e3fe689d385aec1d8535d15#subdirectory=dbt-tests-adapter + # via hatch.envs.default +debugpy==1.8.11 + # via hatch.envs.default +deepdiff==7.0.1 + # via dbt-common +distlib==0.3.9 + # via virtualenv +et-xmlfile==2.0.0 + # via openpyxl +exceptiongroup==1.2.2 + # via pytest +execnet==2.1.1 + # via pytest-xdist +filelock==3.16.1 + # via virtualenv +freezegun==1.5.1 + # via + # hatch.envs.default + # dbt-tests-adapter +google-auth==2.37.0 + # via databricks-sdk +identify==2.6.3 + # via pre-commit +idna==3.10 + # via requests +importlib-metadata==6.11.0 + # via + # dbt-semantic-interfaces + # keyring +iniconfig==2.0.0 + # via pytest +isodate==0.6.1 + # via + # agate + # dbt-common +jaraco-classes==3.4.0 + # via keyring +jaraco-context==6.0.1 + # via keyring +jaraco-functools==4.1.0 + # via keyring +jinja2==3.1.4 + # via + # dbt-common + # dbt-core + # dbt-semantic-interfaces +jsonschema==4.23.0 + # via + # dbt-common + # dbt-semantic-interfaces +jsonschema-specifications==2024.10.1 + # via jsonschema +keyring==25.5.0 + # via hatch.envs.default +leather==0.4.0 + # via agate +lz4==4.3.3 + # via databricks-sql-connector +markupsafe==3.0.2 + # via jinja2 +mashumaro==3.14 + # via + # dbt-adapters + # dbt-common + # dbt-core +more-itertools==10.5.0 + # via + # dbt-semantic-interfaces + # jaraco-classes + # jaraco-functools +msgpack==1.1.0 + # via mashumaro +mypy==1.13.0 + # via hatch.envs.default +mypy-extensions==1.0.0 + # via mypy +networkx==3.2.1 + # via dbt-core +nodeenv==1.9.1 + # via pre-commit +numpy==1.26.4 + # via + # databricks-sql-connector + # pandas + # pyarrow +oauthlib==3.2.2 + # via databricks-sql-connector +openpyxl==3.1.5 + # via databricks-sql-connector +ordered-set==4.1.0 + # via deepdiff +packaging==24.2 + # via + # dbt-core + # pytest +pandas==2.2.3 + # via databricks-sql-connector +parsedatetime==2.6 + # via agate +pathspec==0.12.1 + # via + # dbt-common + # dbt-core +platformdirs==4.3.6 + # via virtualenv +pluggy==1.5.0 + # via pytest +pre-commit==4.0.1 + # via hatch.envs.default +protobuf==5.29.1 + # via + # dbt-adapters + # dbt-common + # dbt-core +pyarrow==16.1.0 + # via databricks-sql-connector +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.1 + # via google-auth +pydantic==1.10.19 + # via + # hatch.envs.default + # dbt-semantic-interfaces +pytest==8.3.4 + # via + # hatch.envs.default + # pytest-dotenv + # pytest-xdist +pytest-dotenv==0.5.2 + # via hatch.envs.default +pytest-xdist==3.6.1 + # via hatch.envs.default +python-dateutil==2.9.0.post0 + # via + # dbt-common + # dbt-semantic-interfaces + # freezegun + # pandas +python-dotenv==1.0.1 + # via pytest-dotenv +python-slugify==8.0.4 + # via agate +pytimeparse==1.1.8 + # via agate +pytz==2024.2 + # via + # dbt-adapters + # dbt-core + # pandas +pyyaml==6.0.2 + # via + # dbt-core + # dbt-semantic-interfaces + # dbt-tests-adapter + # pre-commit +referencing==0.35.1 + # via + # jsonschema + # jsonschema-specifications +requests==2.32.3 + # via + # databricks-sdk + # databricks-sql-connector + # dbt-common + # dbt-core + # snowplow-tracker +rpds-py==0.22.3 + # via + # jsonschema + # referencing +rsa==4.9 + # via google-auth +ruff==0.8.3 + # via hatch.envs.default +six==1.17.0 + # via + # isodate + # python-dateutil + # thrift +snowplow-tracker==1.0.4 + # via dbt-core +sqlparams==6.1.0 + # via dbt-spark 
+sqlparse==0.5.3 + # via dbt-core +text-unidecode==1.3 + # via python-slugify +thrift==0.20.0 + # via databricks-sql-connector +tomli==2.2.1 + # via + # mypy + # pytest +types-requests==2.32.0.20241016 + # via + # hatch.envs.default + # snowplow-tracker +typing-extensions==4.12.2 + # via + # dbt-adapters + # dbt-common + # dbt-core + # dbt-semantic-interfaces + # mashumaro + # mypy + # pydantic + # snowplow-tracker +tzdata==2024.2 + # via pandas +urllib3==2.2.3 + # via + # databricks-sql-connector + # requests + # types-requests +virtualenv==20.28.0 + # via pre-commit +zipp==3.21.0 + # via importlib-metadata From 477b74582f2758847d1fa7e2b2968c7b32e8220e Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Tue, 17 Dec 2024 09:16:47 -0800 Subject: [PATCH 02/21] Revert "Trying lock approach for dependency management (#878)" This reverts commit f94e6edd570a587c78fa0f2bf45a15cd24fe7c5f. --- CHANGELOG.md | 6 - pyproject.toml | 7 -- uv.lock | 306 ------------------------------------------------- 3 files changed, 319 deletions(-) delete mode 100644 uv.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index 98bdcc16..a4c107a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,3 @@ -## dbt-databricks 1.9.2 (TBD) - -### Under the Hood - -- Switch to UV and locks for dependency management ([878](https://github.com/databricks/dbt-databricks/pull/878)) - ## dbt-databricks 1.9.1 (December 16, 2024) ### Features diff --git a/pyproject.toml b/pyproject.toml index a42a8f4e..d2f728d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,14 +62,7 @@ check-sdist = [ "pip freeze | grep dbt-databricks", ] -[tool.hatch.env] -requires = ["hatch-pip-compile"] - [tool.hatch.envs.default] -type = "pip-compile" -pip-compile-resolver = "uv" -lock-filename = "uv.lock" -pip-compile-constraint = "default" dependencies = [ "dbt_common @ git+https://github.com/dbt-labs/dbt-common.git", "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git@main", diff --git a/uv.lock b/uv.lock deleted file mode 100644 index a5783157..00000000 --- a/uv.lock +++ /dev/null @@ -1,306 +0,0 @@ -# -# This file is autogenerated by hatch-pip-compile with Python 3.9 -# -# - dbt_common@ git+https://github.com/dbt-labs/dbt-common.git -# - dbt-adapters@ git+https://github.com/dbt-labs/dbt-adapters.git@main -# - dbt-core@ git+https://github.com/dbt-labs/dbt-core.git@main#subdirectory=core -# - dbt-tests-adapter@ git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter -# - dbt-spark@ git+https://github.com/dbt-labs/dbt-spark.git@main -# - pytest -# - pytest-xdist -# - pytest-dotenv -# - freezegun -# - mypy -# - pre-commit -# - ruff -# - types-requests -# - debugpy -# - pydantic<2,>=1.10.0 -# - databricks-sdk==0.17.0 -# - databricks-sql-connector<4.0.0,>=3.5.0 -# - dbt-adapters<2.0,>=1.7.0 -# - dbt-common<2.0,>=1.10.0 -# - dbt-core<2.0,>=1.8.7 -# - dbt-spark<2.0,>=1.8.0 -# - keyring>=23.13.0 -# - pydantic>=1.10.0 -# - -agate==1.9.1 - # via - # dbt-adapters - # dbt-common - # dbt-core -attrs==24.3.0 - # via - # jsonschema - # referencing -babel==2.16.0 - # via agate -backports-tarfile==1.2.0 - # via jaraco-context -cachetools==5.5.0 - # via google-auth -certifi==2024.12.14 - # via requests -cfgv==3.4.0 - # via pre-commit -charset-normalizer==3.4.0 - # via requests -click==8.1.7 - # via - # dbt-core - # dbt-semantic-interfaces -colorama==0.4.6 - # via dbt-common -daff==1.3.46 - # via dbt-core -databricks-sdk==0.17.0 - # via hatch.envs.default -databricks-sql-connector==3.6.0 - # via hatch.envs.default -dbt-adapters @ 
git+https://github.com/dbt-labs/dbt-adapters.git@e3964d76c1719baf5e3fe689d385aec1d8535d15 - # via - # hatch.envs.default - # dbt-core - # dbt-spark - # dbt-tests-adapter -dbt-common @ git+https://github.com/dbt-labs/dbt-common.git@c72ea7e3abf70ce632d30722036dd0b4afcaf330 - # via - # hatch.envs.default - # dbt-adapters - # dbt-core - # dbt-spark -dbt-core @ git+https://github.com/dbt-labs/dbt-core.git@6c61cb7f7adbdce8edec35a887d6c766a401e403#subdirectory=core - # via - # hatch.envs.default - # dbt-spark - # dbt-tests-adapter -dbt-extractor==0.5.1 - # via dbt-core -dbt-semantic-interfaces==0.8.3 - # via dbt-core -dbt-spark @ git+https://github.com/dbt-labs/dbt-spark.git@a38a288d7d3868c88313350f7d369223b0f03a05 - # via hatch.envs.default -dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git@e3964d76c1719baf5e3fe689d385aec1d8535d15#subdirectory=dbt-tests-adapter - # via hatch.envs.default -debugpy==1.8.11 - # via hatch.envs.default -deepdiff==7.0.1 - # via dbt-common -distlib==0.3.9 - # via virtualenv -et-xmlfile==2.0.0 - # via openpyxl -exceptiongroup==1.2.2 - # via pytest -execnet==2.1.1 - # via pytest-xdist -filelock==3.16.1 - # via virtualenv -freezegun==1.5.1 - # via - # hatch.envs.default - # dbt-tests-adapter -google-auth==2.37.0 - # via databricks-sdk -identify==2.6.3 - # via pre-commit -idna==3.10 - # via requests -importlib-metadata==6.11.0 - # via - # dbt-semantic-interfaces - # keyring -iniconfig==2.0.0 - # via pytest -isodate==0.6.1 - # via - # agate - # dbt-common -jaraco-classes==3.4.0 - # via keyring -jaraco-context==6.0.1 - # via keyring -jaraco-functools==4.1.0 - # via keyring -jinja2==3.1.4 - # via - # dbt-common - # dbt-core - # dbt-semantic-interfaces -jsonschema==4.23.0 - # via - # dbt-common - # dbt-semantic-interfaces -jsonschema-specifications==2024.10.1 - # via jsonschema -keyring==25.5.0 - # via hatch.envs.default -leather==0.4.0 - # via agate -lz4==4.3.3 - # via databricks-sql-connector -markupsafe==3.0.2 - # via jinja2 -mashumaro==3.14 - # via - # dbt-adapters - # dbt-common - # dbt-core -more-itertools==10.5.0 - # via - # dbt-semantic-interfaces - # jaraco-classes - # jaraco-functools -msgpack==1.1.0 - # via mashumaro -mypy==1.13.0 - # via hatch.envs.default -mypy-extensions==1.0.0 - # via mypy -networkx==3.2.1 - # via dbt-core -nodeenv==1.9.1 - # via pre-commit -numpy==1.26.4 - # via - # databricks-sql-connector - # pandas - # pyarrow -oauthlib==3.2.2 - # via databricks-sql-connector -openpyxl==3.1.5 - # via databricks-sql-connector -ordered-set==4.1.0 - # via deepdiff -packaging==24.2 - # via - # dbt-core - # pytest -pandas==2.2.3 - # via databricks-sql-connector -parsedatetime==2.6 - # via agate -pathspec==0.12.1 - # via - # dbt-common - # dbt-core -platformdirs==4.3.6 - # via virtualenv -pluggy==1.5.0 - # via pytest -pre-commit==4.0.1 - # via hatch.envs.default -protobuf==5.29.1 - # via - # dbt-adapters - # dbt-common - # dbt-core -pyarrow==16.1.0 - # via databricks-sql-connector -pyasn1==0.6.1 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.4.1 - # via google-auth -pydantic==1.10.19 - # via - # hatch.envs.default - # dbt-semantic-interfaces -pytest==8.3.4 - # via - # hatch.envs.default - # pytest-dotenv - # pytest-xdist -pytest-dotenv==0.5.2 - # via hatch.envs.default -pytest-xdist==3.6.1 - # via hatch.envs.default -python-dateutil==2.9.0.post0 - # via - # dbt-common - # dbt-semantic-interfaces - # freezegun - # pandas -python-dotenv==1.0.1 - # via pytest-dotenv -python-slugify==8.0.4 - # via agate -pytimeparse==1.1.8 - # via agate 
-pytz==2024.2 - # via - # dbt-adapters - # dbt-core - # pandas -pyyaml==6.0.2 - # via - # dbt-core - # dbt-semantic-interfaces - # dbt-tests-adapter - # pre-commit -referencing==0.35.1 - # via - # jsonschema - # jsonschema-specifications -requests==2.32.3 - # via - # databricks-sdk - # databricks-sql-connector - # dbt-common - # dbt-core - # snowplow-tracker -rpds-py==0.22.3 - # via - # jsonschema - # referencing -rsa==4.9 - # via google-auth -ruff==0.8.3 - # via hatch.envs.default -six==1.17.0 - # via - # isodate - # python-dateutil - # thrift -snowplow-tracker==1.0.4 - # via dbt-core -sqlparams==6.1.0 - # via dbt-spark -sqlparse==0.5.3 - # via dbt-core -text-unidecode==1.3 - # via python-slugify -thrift==0.20.0 - # via databricks-sql-connector -tomli==2.2.1 - # via - # mypy - # pytest -types-requests==2.32.0.20241016 - # via - # hatch.envs.default - # snowplow-tracker -typing-extensions==4.12.2 - # via - # dbt-adapters - # dbt-common - # dbt-core - # dbt-semantic-interfaces - # mashumaro - # mypy - # pydantic - # snowplow-tracker -tzdata==2024.2 - # via pandas -urllib3==2.2.3 - # via - # databricks-sql-connector - # requests - # types-requests -virtualenv==20.28.0 - # via pre-commit -zipp==3.21.0 - # via importlib-metadata From 5f6412d71c12a31ba47446f1e4ec3d642691616a Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:03:34 -0800 Subject: [PATCH 03/21] Refactor reading env vars (#888) --- CHANGELOG.md | 6 +++ dbt/adapters/databricks/connections.py | 16 +++---- dbt/adapters/databricks/credentials.py | 8 ++-- dbt/adapters/databricks/global_state.py | 58 +++++++++++++++++++++++++ dbt/adapters/databricks/impl.py | 6 +-- dbt/adapters/databricks/logging.py | 4 +- tests/unit/test_adapter.py | 33 +++++++++----- 7 files changed, 102 insertions(+), 29 deletions(-) create mode 100644 dbt/adapters/databricks/global_state.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a4c107a9..59af816c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt-databricks 1.9.2 (TBD) + +### Under the Hood + +- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) + ## dbt-databricks 1.9.1 (December 16, 2024) ### Features diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 509686d7..0b523574 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -59,6 +59,7 @@ CursorCreate, ) from dbt.adapters.databricks.events.other_events import QueryError +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -86,9 +87,6 @@ DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") -# toggle for session managements that minimizes the number of sessions opened/closed -USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" - # Number of idle seconds before a connection is automatically closed. Only applicable if # USE_LONG_SESSIONS is true. # Updated when idle times of 180s were causing errors @@ -475,6 +473,8 @@ def add_query( auto_begin: bool = True, bindings: Optional[Any] = None, abridge_sql_log: bool = False, + retryable_exceptions: tuple[type[Exception], ...] 
= tuple(), + retry_limit: int = 1, *, close_cursor: bool = False, ) -> tuple[Connection, Any]: @@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe class ExtendedSessionConnectionManager(DatabricksConnectionManager): def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None: assert ( - USE_LONG_SESSIONS + GlobalState.get_use_long_sessions() ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" super().__init__(profile, mp_context) self.threads_compute_connections: dict[ @@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection: # Once long session management is no longer under the USE_LONG_SESSIONS toggle # this should be renamed and replace the _open class method. assert ( - USE_LONG_SESSIONS + GlobalState.get_use_long_sessions() ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" databricks_connection = cast(DatabricksDBTConnection, connection) @@ -1013,7 +1013,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O # If there is no node we return the http_path for the default compute. if not query_header_context: - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path @@ -1021,7 +1021,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O # If none is specified return the http_path for the default compute. compute_name = _get_compute_name(query_header_context) if not compute_name: - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") return creds.http_path @@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O f"does not specify http_path, relation: {relation_name}" ) - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug( f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." ) diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 7a318cad..387d0e76 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -19,10 +19,10 @@ CredentialSaveError, CredentialShardEvent, ) +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" -DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" @@ -150,7 +150,7 @@ def validate_creds(self) -> None: @classmethod def get_invocation_env(cls) -> Optional[str]: - invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV) + invocation_env = GlobalState.get_invocation_env() if invocation_env: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. 
@@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]: @classmethod def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: - http_session_headers_str: Optional[str] = os.environ.get( - DBT_DATABRICKS_HTTP_SESSION_HEADERS - ) + http_session_headers_str = GlobalState.get_http_session_headers() http_session_headers_dict: dict[str, str] = ( { diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py new file mode 100644 index 00000000..de240d39 --- /dev/null +++ b/dbt/adapters/databricks/global_state.py @@ -0,0 +1,58 @@ +import os +from typing import ClassVar, Optional + + +class GlobalState: + """Global state is a bad idea, but since we don't control instantiation, better to have it in a + single place than scattered throughout the codebase. + """ + + __use_long_sessions: ClassVar[Optional[bool]] = None + + @classmethod + def get_use_long_sessions(cls) -> bool: + if cls.__use_long_sessions is None: + cls.__use_long_sessions = ( + os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + ) + return cls.__use_long_sessions + + __invocation_env: ClassVar[Optional[str]] = None + __invocation_env_set: ClassVar[bool] = False + + @classmethod + def get_invocation_env(cls) -> Optional[str]: + if not cls.__invocation_env_set: + cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV") + cls.__invocation_env_set = True + return cls.__invocation_env + + __session_headers: ClassVar[Optional[str]] = None + __session_headers_set: ClassVar[bool] = False + + @classmethod + def get_http_session_headers(cls) -> Optional[str]: + if not cls.__session_headers_set: + cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS") + cls.__session_headers_set = True + return cls.__session_headers + + __describe_char_bypass: ClassVar[Optional[bool]] = None + + @classmethod + def get_char_limit_bypass(cls) -> bool: + if cls.__describe_char_bypass is None: + cls.__describe_char_bypass = ( + os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" + ) + return cls.__describe_char_bypass + + __connector_log_level: ClassVar[Optional[str]] = None + + @classmethod + def get_connector_log_level(cls) -> str: + if cls.__connector_log_level is None: + cls.__connector_log_level = os.getenv( + "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN" + ).upper() + return cls.__connector_log_level diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index dce432c9..15c333e2 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -32,10 +32,10 @@ ) from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.databricks.connections import ( - USE_LONG_SESSIONS, DatabricksConnectionManager, ExtendedSessionConnectionManager, ) +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, @@ -142,7 +142,7 @@ def get_identifier_list_string(table_names: set[str]) -> str: """ _identifier = "|".join(table_names) - bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false") + bypass_2048_char_limit = GlobalState.get_char_limit_bypass() if bypass_2048_char_limit == "true": _identifier = _identifier if len(_identifier) < 2048 else "*" return _identifier @@ -154,7 +154,7 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - if USE_LONG_SESSIONS: + if 
GlobalState.get_use_long_sessions(): ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager else: ConnectionManager = DatabricksConnectionManager diff --git a/dbt/adapters/databricks/logging.py b/dbt/adapters/databricks/logging.py index d0f1d42b..81e7449e 100644 --- a/dbt/adapters/databricks/logging.py +++ b/dbt/adapters/databricks/logging.py @@ -1,7 +1,7 @@ -import os from logging import Handler, LogRecord, getLogger from typing import Union +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.events.logging import AdapterLogger logger = AdapterLogger("Databricks") @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None: dbt_adapter_logger = AdapterLogger("databricks-sql-connector") pysql_logger = getLogger("databricks.sql") -pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper() +pysql_logger_level = GlobalState.get_connector_log_level() pysql_logger.setLevel(pysql_logger_level) pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 78ae12cb..d42fa5e1 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -11,8 +11,6 @@ from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.databricks.credentials import ( CATALOG_KEY_IN_SESSION_PROPERTIES, - DBT_DATABRICKS_HTTP_SESSION_HEADERS, - DBT_DATABRICKS_INVOCATION_ENV, ) from dbt.adapters.databricks.impl import get_identifier_list_string from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType @@ -114,7 +112,10 @@ def test_invalid_custom_user_agent(self): with pytest.raises(DbtValidationError) as excinfo: config = self._get_config() adapter = DatabricksAdapter(config, get_context("spawn")) - with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", + return_value="(Some-thing)", + ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -128,8 +129,9 @@ def test_custom_user_agent(self): "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_invocation_env="databricks-workflows"), ): - with patch.dict( - "os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"} + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", + return_value="databricks-workflows", ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -190,9 +192,9 @@ def _test_environment_http_headers( "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_http_headers=expected_http_headers), ): - with patch.dict( - "os.environ", - **{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str}, + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers", + return_value=http_headers_str, ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -912,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self): assert get_identifier_list_string(table_names) == "|".join(table_names) # If environment variable is set, then limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # Long list of 
table names is capped assert get_identifier_list_string(table_names) == "*" @@ -941,7 +946,10 @@ def test_describe_table_extended_should_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # Long list of table names is capped assert get_identifier_list_string(table_names) == "*" @@ -954,7 +962,10 @@ def test_describe_table_extended_may_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then we may limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # But a short list of table names is not capped assert get_identifier_list_string(list(table_names)[:5]) == "|".join( list(table_names)[:5] From b0ff51bd7198d456681a5b1ca822e5adba34b732 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Fri, 20 Dec 2024 13:53:55 -0800 Subject: [PATCH 04/21] Use single quotes in gets in templates (#889) --- .../macros/materializations/seeds/helpers.sql | 2 +- .../macros/relations/constraints.sql | 34 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/dbt/include/databricks/macros/materializations/seeds/helpers.sql b/dbt/include/databricks/macros/materializations/seeds/helpers.sql index df690f18..82acaba3 100644 --- a/dbt/include/databricks/macros/materializations/seeds/helpers.sql +++ b/dbt/include/databricks/macros/materializations/seeds/helpers.sql @@ -6,7 +6,7 @@ {% set batch_size = get_batch_size() %} {% set column_override = model['config'].get('column_types', {}) %} - {% set must_cast = model['config'].get("file_format", "delta") == "parquet" %} + {% set must_cast = model['config'].get('file_format', 'delta') == 'parquet' %} {% set statements = [] %} diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql index 68f3a44f..34d5b415 100644 --- a/dbt/include/databricks/macros/relations/constraints.sql +++ b/dbt/include/databricks/macros/relations/constraints.sql @@ -106,19 +106,19 @@ {% macro get_constraint_sql(relation, constraint, model, column={}) %} {% set statements = [] %} - {% set type = constraint.get("type", "") %} + {% set type = constraint.get('type', '') %} {% if type == 'check' %} - {% set expression = constraint.get("expression", "") %} + {% set expression = constraint.get('expression', '') %} {% if not expression %} {{ exceptions.raise_compiler_error('Invalid check constraint expression') }} {% endif %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} - {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get("name", "") ~ ";" ~ expression ~ ";") -%} + {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get('name', '') ~ ";" ~ expression ~ ";") -%} {% else %} {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} @@ -126,7 +126,7 @@ {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " check (" ~ expression ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'not_null' %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -144,7 +144,7 @@ {% if constraint.get('warn_unenforced') %} {{ exceptions.warn("unenforced constraint type: " ~ type)}} {% endif %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -161,7 +161,7 @@ {% set joined_names = quoted_names|join(", ") %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }} @@ -178,7 +178,7 @@ {{ exceptions.warn("unenforced constraint type: " ~ constraint.type)}} {% endif %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if constraint.get('expression') %} @@ -193,7 +193,7 @@ {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} {% else %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -210,7 +210,7 @@ {% set joined_names = quoted_names|join(", ") %} - {% set parent = constraint.get("to") %} + {% set parent = constraint.get('to') %} {% if not parent %} {{ exceptions.raise_compiler_error('No parent table defined for foreign key: ' ~ expression) }} {% endif %} @@ -228,7 +228,7 @@ {% endif %} {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} - {% set parent_columns = constraint.get("to_columns") %} + {% set parent_columns = constraint.get('to_columns') %} {% if parent_columns %} {% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%} {% endif %} @@ -236,13 +236,13 @@ {% set stmt = stmt ~ ";" %} {% do statements.append(stmt) %} {% elif type == 'custom' %} - {% set expression = constraint.get("expression", "") %} + {% set expression = constraint.get('expression', '') %} {% if not expression %} {{ exceptions.raise_compiler_error('Missing custom constraint expression') }} {% endif %} - {% set name = constraint.get("name") %} - {% set expression = constraint.get("expression") %} + {% set name = constraint.get('name') %} + {% set expression = constraint.get('expression') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} @@ -264,15 +264,15 @@ {# convert constraints defined using the original databricks format #} {% set dbt_constraints = [] %} {% for constraint in constraints %} - {% if constraint.get and constraint.get("type") %} + {% if constraint.get and constraint.get('type') %} {# already in model contract format #} {% do dbt_constraints.append(constraint) %} {% else %} {% if column %} {% if constraint == "not_null" %} - {% do dbt_constraints.append({"type": "not_null", "columns": [column.get("name")]}) %} + {% do dbt_constraints.append({"type": "not_null", "columns": [column.get('name')]}) %} {% else %} - {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get("name", "") ~ '. Only `not_null` is supported.') }} + {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get('name', "") ~ '. Only `not_null` is supported.') }} {% endif %} {% else %} {% set name = constraint['name'] %} From 3fb54dc65641c2f7b2be1004ef6052dd471e4dec Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Wed, 15 Jan 2025 09:58:02 -0800 Subject: [PATCH 05/21] Switch to render for safest relation str substitution (#903) --- CHANGELOG.md | 1 + dbt/include/databricks/macros/adapters/columns.sql | 4 ++-- .../databricks/macros/adapters/persist_docs.sql | 4 ++-- .../databricks/macros/relations/constraints.sql | 10 +++++----- .../databricks/macros/relations/liquid_clustering.sql | 2 +- .../macros/relations/materialized_view/alter.sql | 2 +- .../macros/relations/materialized_view/drop.sql | 2 +- .../macros/relations/materialized_view/refresh.sql | 2 +- .../macros/relations/streaming_table/drop.sql | 2 +- .../macros/relations/streaming_table/refresh.sql | 2 +- .../databricks/macros/relations/table/create.sql | 4 ++-- dbt/include/databricks/macros/relations/table/drop.sql | 2 +- dbt/include/databricks/macros/relations/tags.sql | 4 ++-- .../databricks/macros/relations/tblproperties.sql | 2 +- .../databricks/macros/relations/view/create.sql | 2 +- dbt/include/databricks/macros/relations/view/drop.sql | 2 +- 16 files changed, 24 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 59af816c..c01efeeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Under the Hood - Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) +- Switch to relation.render() for string interpolation ([903](https://github.com/databricks/dbt-databricks/pull/903)) ## dbt-databricks 1.9.1 (December 16, 2024) diff --git a/dbt/include/databricks/macros/adapters/columns.sql b/dbt/include/databricks/macros/adapters/columns.sql index 7fe40e6f..d9b041cc 100644 --- a/dbt/include/databricks/macros/adapters/columns.sql +++ b/dbt/include/databricks/macros/adapters/columns.sql @@ -32,13 +32,13 @@ {{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }} {% endif %} {%- call statement('alter_relation_remove_columns') -%} - ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) + ALTER TABLE {{ relation.render() }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) {%- endcall -%} {% endif %} {% if add_columns %} {%- call statement('alter_relation_add_columns') -%} - ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }}) + ALTER TABLE {{ relation.render() }} ADD COLUMNS ({{ 
api.Column.format_add_column_list(add_columns) }}) {%- endcall -%} {% endif %} {% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/adapters/persist_docs.sql b/dbt/include/databricks/macros/adapters/persist_docs.sql index 873039e8..a8ad48ba 100644 --- a/dbt/include/databricks/macros/adapters/persist_docs.sql +++ b/dbt/include/databricks/macros/adapters/persist_docs.sql @@ -4,7 +4,7 @@ {% set comment = column['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; + alter table {{ relation.render()|lower }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; {% endset %} {% do run_query(comment_query) %} {% endfor %} @@ -13,7 +13,7 @@ {% macro alter_table_comment(relation, model) %} {% set comment_query %} - comment on table {{ relation|lower }} is '{{ model.description | replace("'", "\\'") }}' + comment on table {{ relation.render()|lower }} is '{{ model.description | replace("'", "\\'") }}' {% endset %} {% do run_query(comment_query) %} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql index 34d5b415..6d999823 100644 --- a/dbt/include/databricks/macros/relations/constraints.sql +++ b/dbt/include/databricks/macros/relations/constraints.sql @@ -134,7 +134,7 @@ {% set column = model.get('columns', {}).get(column_name) %} {% if column %} {% set quoted_name = api.Column.get_name(column) %} - {% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} + {% set stmt = "alter table " ~ relation.render() ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} {% do statements.append(stmt) %} {% else %} {{ exceptions.warn('not_null constraint on invalid column: ' ~ column_name) }} @@ -170,7 +170,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'foreign_key' %} @@ -191,7 +191,7 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} {% else %} {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} @@ -227,7 +227,7 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} {% set parent_columns = constraint.get('to_columns') %} {% if parent_columns %} {% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%} @@ -251,7 +251,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add 
constraint " ~ name ~ " " ~ expression ~ ";" %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " " ~ expression ~ ";" %} {% do statements.append(stmt) %} {% elif constraint.get('warn_unsupported') %} {{ exceptions.warn("unsupported constraint type: " ~ constraint.type)}} diff --git a/dbt/include/databricks/macros/relations/liquid_clustering.sql b/dbt/include/databricks/macros/relations/liquid_clustering.sql index 3cf81048..b30269fd 100644 --- a/dbt/include/databricks/macros/relations/liquid_clustering.sql +++ b/dbt/include/databricks/macros/relations/liquid_clustering.sql @@ -15,7 +15,7 @@ {%- set cols = [cols] -%} {%- endif -%} {%- call statement('set_cluster_by_columns') -%} - ALTER {{ target_relation.type }} {{ target_relation }} CLUSTER BY ({{ cols | join(', ') }}) + ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY ({{ cols | join(', ') }}) {%- endcall -%} {%- endif %} {%- endmacro -%} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/materialized_view/alter.sql b/dbt/include/databricks/macros/relations/materialized_view/alter.sql index 41d9bed0..d406508d 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/alter.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/alter.sql @@ -46,6 +46,6 @@ {% macro get_alter_mv_internal(relation, configuration_changes) %} {%- set refresh = configuration_changes.changes["refresh"] -%} -- Currently only schedule can be altered - ALTER MATERIALIZED VIEW {{ relation }} + ALTER MATERIALIZED VIEW {{ relation.render() }} {{ get_alter_sql_refresh_schedule(refresh.cron, refresh.time_zone_value, refresh.is_altered) -}} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/drop.sql b/dbt/include/databricks/macros/relations/materialized_view/drop.sql index f3774119..4def4744 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/drop.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_materialized_view(relation) -%} - drop materialized view if exists {{ relation }} + drop materialized view if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql index 10a8346b..9967eb21 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql @@ -1,3 +1,3 @@ {% macro databricks__refresh_materialized_view(relation) -%} - refresh materialized view {{ relation }} + refresh materialized view {{ relation.render() }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/drop.sql b/dbt/include/databricks/macros/relations/streaming_table/drop.sql index c8e0cd83..1cfc246a 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/drop.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/drop.sql @@ -3,5 +3,5 @@ {%- endmacro %} {% macro default__drop_streaming_table(relation) -%} - drop table if exists {{ relation }} + drop table if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql index 66b86f1f..94c96d5c 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql +++ 
b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql @@ -3,7 +3,7 @@ {%- endmacro %} {% macro databricks__refresh_streaming_table(relation, sql) -%} - create or refresh streaming table {{ relation }} + create or refresh streaming table {{ relation.render() }} as {{ sql }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/table/create.sql b/dbt/include/databricks/macros/relations/table/create.sql index 9e74d57d..b2aba2fe 100644 --- a/dbt/include/databricks/macros/relations/table/create.sql +++ b/dbt/include/databricks/macros/relations/table/create.sql @@ -5,9 +5,9 @@ {%- else -%} {%- set file_format = config.get('file_format', default='delta') -%} {% if file_format == 'delta' %} - create or replace table {{ relation }} + create or replace table {{ relation.render() }} {% else %} - create table {{ relation }} + create table {{ relation.render() }} {% endif %} {%- set contract_config = config.get('contract') -%} {% if contract_config and contract_config.enforced %} diff --git a/dbt/include/databricks/macros/relations/table/drop.sql b/dbt/include/databricks/macros/relations/table/drop.sql index 3a7d0ced..7bce7cf4 100644 --- a/dbt/include/databricks/macros/relations/table/drop.sql +++ b/dbt/include/databricks/macros/relations/table/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_table(relation) -%} - drop table if exists {{ relation }} + drop table if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/tags.sql b/dbt/include/databricks/macros/relations/tags.sql index 3467631d..fb39c378 100644 --- a/dbt/include/databricks/macros/relations/tags.sql +++ b/dbt/include/databricks/macros/relations/tags.sql @@ -33,7 +33,7 @@ {%- endmacro -%} {% macro alter_set_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation }} SET TAGS ( + ALTER {{ relation.type }} {{ relation.render() }} SET TAGS ( {% for tag in tags -%} '{{ tag }}' = '{{ tags[tag] }}' {%- if not loop.last %}, {% endif -%} {%- endfor %} @@ -41,7 +41,7 @@ {%- endmacro -%} {% macro alter_unset_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation }} UNSET TAGS ( + ALTER {{ relation.type }} {{ relation.render() }} UNSET TAGS ( {% for tag in tags -%} '{{ tag }}' {%- if not loop.last %}, {%- endif %} {%- endfor %} diff --git a/dbt/include/databricks/macros/relations/tblproperties.sql b/dbt/include/databricks/macros/relations/tblproperties.sql index 34b6488f..b11fd7b5 100644 --- a/dbt/include/databricks/macros/relations/tblproperties.sql +++ b/dbt/include/databricks/macros/relations/tblproperties.sql @@ -17,7 +17,7 @@ {% set tblproperty_statment = databricks__tblproperties_clause(tblproperties) %} {% if tblproperty_statment %} {%- call statement('apply_tblproperties') -%} - ALTER {{ relation.type }} {{ relation }} SET {{ tblproperty_statment}} + ALTER {{ relation.type }} {{ relation.render() }} SET {{ tblproperty_statment}} {%- endcall -%} {% endif %} {%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/view/create.sql b/dbt/include/databricks/macros/relations/view/create.sql index 096e12de..5399b4ef 100644 --- a/dbt/include/databricks/macros/relations/view/create.sql +++ b/dbt/include/databricks/macros/relations/view/create.sql @@ -1,5 +1,5 @@ {% macro databricks__create_view_as(relation, sql) -%} - create or replace view {{ relation }} + create or replace view {{ relation.render() }} {% if config.persist_column_docs() -%} {% set model_columns = model.columns %} {% set query_columns = get_columns_in_query(sql) %} diff --git 
a/dbt/include/databricks/macros/relations/view/drop.sql b/dbt/include/databricks/macros/relations/view/drop.sql index aa199d76..9098c925 100644 --- a/dbt/include/databricks/macros/relations/view/drop.sql +++ b/dbt/include/databricks/macros/relations/view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_view(relation) -%} - drop view if exists {{ relation }} + drop view if exists {{ relation.render() }} {%- endmacro %} From 67bf98d0fa7d24a2affc4d2311c06827d5fba8ae Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Fri, 17 Jan 2025 10:38:15 -0800 Subject: [PATCH 06/21] Adding missing 1.9 Snapshot behavior (#904) --- CHANGELOG.md | 4 + .../macros/materializations/snapshot.sql | 88 ++--- .../simple_snapshot/test_new_record_mode.py | 74 ++++ .../simple_snapshot/test_various_configs.py | 345 ++++++++++++++++++ 4 files changed, 459 insertions(+), 52 deletions(-) create mode 100644 tests/functional/adapter/simple_snapshot/test_new_record_mode.py create mode 100644 tests/functional/adapter/simple_snapshot/test_various_configs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c01efeeb..54f17963 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.9.2 (TBD) +### Features + +- Update snapshot materialization to support new snapshot features ([904](https://github.com/databricks/dbt-databricks/pull/904)) + ### Under the Hood - Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) diff --git a/dbt/include/databricks/macros/materializations/snapshot.sql b/dbt/include/databricks/macros/materializations/snapshot.sql index 3d1236a1..3a513a24 100644 --- a/dbt/include/databricks/macros/materializations/snapshot.sql +++ b/dbt/include/databricks/macros/materializations/snapshot.sql @@ -1,27 +1,4 @@ -{% macro databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - - {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, - schema=target_relation.schema, - database=target_relation.database, - type='view') -%} - - {% set select = snapshot_staging_table(strategy, sql, target_relation) %} - - {# needs to be a non-temp view so that its columns can be ascertained via `describe` #} - {% call statement('build_snapshot_staging_relation') %} - create or replace view {{ tmp_relation }} - as - {{ select }} - {% endcall %} - - {% do return(tmp_relation) %} -{% endmacro %} - - {% materialization snapshot, adapter='databricks' %} - {%- set config = model['config'] -%} - {%- set target_table = model.get('alias', model.get('name')) -%} {%- set strategy_name = config.get('strategy') -%} @@ -62,47 +39,43 @@ {{ run_hooks(pre_hooks, inside_transaction=True) }} {% set strategy_macro = strategy_dispatch(strategy_name) %} - {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", model['config'], target_relation_exists) %} {% if not target_relation_exists %} {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} + {% set build_or_select_sql = build_sql %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} - {% call statement('main') %} - {{ final_sql }} - {% endcall %} - - {% do persist_docs(target_relation, model, for_relation=False) %} - {% else %} - {{ adapter.valid_snapshot_target(target_relation) }} + {% set columns = 
config.get("snapshot_table_column_names") or get_snapshot_table_column_names() %} - {% if target_relation.database is none %} - {% set staging_table = spark_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% else %} - {% set staging_table = databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% endif %} + {{ adapter.assert_valid_snapshot_target_given_strategy(target_relation, columns, strategy) }} + + {% set build_or_select_sql = snapshot_staging_table(strategy, sql, target_relation) %} + {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} -- this may no-op if the database does not require column expansion {% do adapter.expand_target_column_types(from_relation=staging_table, to_relation=target_relation) %} + {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} + {% if unique_key | is_list %} + {% for key in strategy.unique_key %} + {{ remove_columns.append('dbt_unique_key_' + loop.index|string) }} + {{ remove_columns.append('DBT_UNIQUE_KEY_' + loop.index|string) }} + {% endfor %} + {% endif %} + {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% do create_columns(target_relation, missing_columns) %} {% set source_columns = adapter.get_columns_in_relation(staging_table) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% set quoted_source_columns = [] %} @@ -117,23 +90,34 @@ ) %} - {% call statement_with_staging_table('main', staging_table) %} - {{ final_sql }} - {% endcall %} + {% endif %} - {% do persist_docs(target_relation, model, for_relation=True) %} - {% endif %} + {{ check_time_data_types(build_or_select_sql) }} - {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode) %} - {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% call statement('main') %} + {{ final_sql }} + {% endcall %} - {% do persist_constraints(target_relation, model) %} + {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode=False) %} + {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} + + {% do persist_docs(target_relation, model) %} + + {% if not target_relation_exists %} + {% do create_indexes(target_relation) %} + {% endif %} {{ run_hooks(post_hooks, inside_transaction=True) }} {{ adapter.commit() }} + {% if staging_table is defined %} + {% do post_snapshot(staging_table) %} + {% endif %} + + {% do persist_constraints(target_relation, model) %} + {{ run_hooks(post_hooks, inside_transaction=False) }} {{ return({'relations': [target_relation]}) }} diff --git a/tests/functional/adapter/simple_snapshot/test_new_record_mode.py b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py new file mode 100644 index 00000000..6b436a31 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py @@ -0,0 +1,74 @@ +import pytest + +from dbt.tests.adapter.simple_snapshot.new_record_mode import ( + _delete_sql, + _invalidate_sql, + _ref_snapshot_sql, + 
_seed_new_record_mode, + _snapshot_actual_sql, + _snapshots_yml, + _update_sql, +) +from dbt.tests.util import check_relations_equal, run_dbt + + +class TestDatabricksSnapshotNewRecordMode: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": _snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": _snapshots_yml, + "ref_snapshot.sql": _ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def seed_new_record_mode(self): + return _seed_new_record_mode + + @pytest.fixture(scope="class") + def invalidate_sql_1(self): + return _invalidate_sql.split(";", 1)[0].replace("BEGIN", "") + + @pytest.fixture(scope="class") + def invalidate_sql_2(self): + return _invalidate_sql.split(";", 1)[1].replace("END", "").replace(";", "") + + @pytest.fixture(scope="class") + def update_sql(self): + return _update_sql.replace("text", "string") + + @pytest.fixture(scope="class") + def delete_sql(self): + return _delete_sql + + def test_snapshot_new_record_mode( + self, project, seed_new_record_mode, invalidate_sql_1, invalidate_sql_2, update_sql + ): + for sql in ( + seed_new_record_mode.replace("text", "string") + .replace("TEXT", "STRING") + .replace("BEGIN", "") + .replace("END;", "") + .replace(" WITHOUT TIME ZONE", "") + .split(";") + ): + project.run_sql(sql) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + project.run_sql(_delete_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 diff --git a/tests/functional/adapter/simple_snapshot/test_various_configs.py b/tests/functional/adapter/simple_snapshot/test_various_configs.py new file mode 100644 index 00000000..18b82de0 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_various_configs.py @@ -0,0 +1,345 @@ +import datetime + +import pytest +from agate import Table + +from dbt.tests.adapter.simple_snapshot.fixtures import ( + create_multi_key_seed_sql, + create_multi_key_snapshot_expected_sql, + create_seed_sql, + create_snapshot_expected_sql, + model_seed_sql, + populate_multi_key_snapshot_expected_sql, + populate_snapshot_expected_sql, + populate_snapshot_expected_valid_to_current_sql, + ref_snapshot_sql, + seed_insert_sql, + seed_multi_key_insert_sql, + snapshot_actual_sql, + snapshots_multi_key_yml, + snapshots_no_column_names_yml, + snapshots_valid_to_current_yml, + snapshots_yml, + update_multi_key_sql, + update_sql, + update_with_current_sql, +) +from dbt.tests.util import ( + check_relations_equal, + get_manifest, + run_dbt, + run_dbt_and_capture, + run_sql_with_adapter, + update_config_file, +) + + +def text_replace(input: str) -> str: + return input.replace("TEXT", "STRING").replace("text", "string") + + +create_snapshot_expected_sql = text_replace(create_snapshot_expected_sql) +populate_snapshot_expected_sql = text_replace(populate_snapshot_expected_sql) +populate_snapshot_expected_valid_to_current_sql = text_replace( + populate_snapshot_expected_valid_to_current_sql +) +update_with_current_sql = text_replace(update_with_current_sql) +create_multi_key_snapshot_expected_sql = text_replace(create_multi_key_snapshot_expected_sql) +populate_multi_key_snapshot_expected_sql = text_replace(populate_multi_key_snapshot_expected_sql) +update_sql = text_replace(update_sql) 
+update_multi_key_sql = text_replace(update_multi_key_sql) + +invalidate_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id >= 10 and id <= 20 +""" + +invalidate_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id >= 10 and id <= 20; +""" + +invalidate_multi_key_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id1 = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id1 >= 10 and id1 <= 20; +""" + +invalidate_multi_key_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id1 >= 10 and id1 <= 20; +""" + + +class BaseSnapshotColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_snapshot_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotColumnNamesFromDbtProject: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_column_names_from_project(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotInvalidColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + 
"dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_invalid_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + manifest = get_manifest(project.project_root) + snapshot_node = manifest.nodes["snapshot.test.snapshot_actual"] + snapshot_node.config.snapshot_meta_column_names == { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + # Change snapshot_meta_columns and look for an error + different_columns = { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_updated_at": "test_updated_at", + } + } + } + } + update_config_file(different_columns, "dbt_project.yml") + + results, log_output = run_dbt_and_capture(["snapshot"], expect_pass=False) + assert len(results) == 1 + assert "Compilation Error in snapshot snapshot_actual" in log_output + assert "Snapshot target is missing configured columns" in log_output + + +class BaseSnapshotDbtValidToCurrent: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_valid_to_current_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_valid_to_current(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_valid_to_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + original_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + assert original_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", original_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_with_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + updated_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + print(updated_snapshot) + assert updated_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + # Original row that was updated now has a non-current (2099/12/31) date + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", updated_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2016, 8, 20, 16, 44, 49, tzinfo=datetime.timezone.utc + ) + updated_row = list( + filter(lambda x: x[1] == "af1f803f2179869aeacb1bfe2b23c1df", updated_snapshot) + ) + + # Updated row has a current date + assert updated_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +# This uses snapshot_meta_column_names, yaml-only snapshot def, 
+# and multiple keys +class BaseSnapshotMultiUniqueKey: + @pytest.fixture(scope="class") + def models(self): + return { + "seed.sql": model_seed_sql, + "snapshots.yml": snapshots_multi_key_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_multi_column_unique_key(self, project): + project.run_sql(create_multi_key_seed_sql) + project.run_sql(create_multi_key_snapshot_expected_sql) + project.run_sql(seed_multi_key_insert_sql) + project.run_sql(populate_multi_key_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_multi_key_sql_1) + project.run_sql(invalidate_multi_key_sql_2) + project.run_sql(update_multi_key_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestDatabricksSnapshotColumnNames(BaseSnapshotColumnNames): + pass + + +class TestDatabricksSnapshotColumnNamesFromDbtProject(BaseSnapshotColumnNamesFromDbtProject): + pass + + +class TestDatabricksSnapshotInvalidColumnNames(BaseSnapshotInvalidColumnNames): + pass + + +class TestDatabricksSnapshotDbtValidToCurrent(BaseSnapshotDbtValidToCurrent): + pass + + +class TestDatabricksSnapshotMultiUniqueKey(BaseSnapshotMultiUniqueKey): + pass From 45ec259ff093dc9b0474b2d1882e56f85627465e Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Fri, 17 Jan 2025 12:25:03 -0800 Subject: [PATCH 07/21] Enforce retry defaults to ensure sufficient retries regardless of PySQL (#907) --- CHANGELOG.md | 1 + dbt/adapters/databricks/credentials.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54f17963..11428fad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) - Switch to relation.render() for string interpolation ([903](https://github.com/databricks/dbt-databricks/pull/903)) +- Ensure retry defaults for PySQL ([907](https://github.com/databricks/dbt-databricks/pull/907)) ## dbt-databricks 1.9.1 (December 16, 2024) diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 387d0e76..250e79f6 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -74,8 +74,10 @@ class DatabricksCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) - if "database" not in data: - data["database"] = None + data.setdefault("database", None) + data.setdefault("connection_parameters", {}) + data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30) + data["connection_parameters"].setdefault("_retry_delay_max", 60) return data def __post_init__(self) -> None: From 6974edb95a7aa901cc335f8c07531f8e0c710399 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:12:31 -0800 Subject: [PATCH 08/21] Prep for 1.9.2 release (#908) --- CHANGELOG.md | 2 +- dbt/adapters/databricks/__version__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11428fad..20511cb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## dbt-databricks 1.9.2 (TBD) +## dbt-databricks 1.9.2 (Jan 21, 2024) ### Features diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py 
index 70227976..1b022739 100644 --- a/dbt/adapters/databricks/__version__.py +++ b/dbt/adapters/databricks/__version__.py @@ -1 +1 @@ -version = "1.9.1" +version = "1.9.2" From 395801ec7810c8c8dd7d71e461a66cdda35c5e18 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 23 Jan 2025 09:23:57 -0800 Subject: [PATCH 09/21] Compress to one connection manager (#910) --- CHANGELOG.md | 6 + dbt/adapters/databricks/connections.py | 386 ++++++++++-------------- dbt/adapters/databricks/global_state.py | 10 - dbt/adapters/databricks/impl.py | 10 +- 4 files changed, 162 insertions(+), 250 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20511cb1..3e72506e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt-databricks 1.9.3 (TBD) + +### Under the Hood + +- Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) + ## dbt-databricks 1.9.2 (Jan 21, 2024) ### Features diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 0b523574..3a6b4817 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1,5 +1,4 @@ import decimal -import os import re import sys import time @@ -10,7 +9,6 @@ from dataclasses import dataclass from multiprocessing.context import SpawnContext from numbers import Number -from threading import get_ident from typing import TYPE_CHECKING, Any, Optional, cast from dbt_common.events.contextvars import get_node_info @@ -59,7 +57,6 @@ CursorCreate, ) from dbt.adapters.databricks.events.other_events import QueryError -from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -383,6 +380,9 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) creds = cast(DatabricksCredentials, self.profile.credentials) self.api_client = DatabricksApiClient.create(creds, 15 * 60) + self.threads_compute_connections: dict[ + Hashable, dict[Hashable, DatabricksDBTConnection] + ] = {} def cancel_open(self) -> list[str]: cancelled = super().cancel_open() @@ -431,39 +431,19 @@ def set_connection_name( 'connection_named', called by 'connection_for(node)'. 
Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" + self._cleanup_idle_connections() conn_name: str = "master" if name is None else name # Get a connection for this thread - conn = self.get_if_exists() - - if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn + conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") if conn is None: - # Create a new connection - conn = DatabricksDBTConnection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - # Add the connection to thread_connections for this thread - self.set_thread_connection(conn) - fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) - ) + conn = self._create_compute_connection(conn_name, query_header_context) else: # existing connection either wasn't open or didn't have the right name - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - if conn.name != conn_name: - orig_conn_name: str = conn.name or "" - conn.name = conn_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + conn = self._update_compute_connection(conn, conn_name) + + conn._acquire(query_header_context) return conn @@ -601,6 +581,34 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No ), ) + # override + def release(self) -> None: + with self.lock: + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if conn is None: + return + + conn._release() + + # override + def cleanup_all(self) -> None: + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) + ) + else: + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) + ) + self.close(connection) + + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() + @classmethod def get_open_for_context( cls, query_header_context: Any = None @@ -617,13 +625,8 @@ def open_for_model(connection: Connection) -> Connection: @classmethod def open(cls, connection: Connection) -> Connection: - # Simply call _open with no ResultNode argument. - # Because this is an overridden method we can't just add - # a ResultNode parameter to open. - return cls._open(connection) + databricks_connection = cast(DatabricksDBTConnection, connection) - @classmethod - def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: if connection.state == ConnectionState.OPEN: return connection @@ -646,12 +649,12 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn # If a model specifies a compute resource the http path # may be different than the http_path property of creds. 
- http_path = _get_http_path(query_header_context, creds) + http_path = databricks_connection.http_path def connect() -> DatabricksSQLConnectionWrapper: try: # TODO: what is the error when a user specifies a catalog they don't have access to - conn: DatabricksSQLConnection = dbsql.connect( + conn = dbsql.connect( server_hostname=creds.host, http_path=http_path, credentials_provider=cls.credentials_provider, @@ -663,7 +666,11 @@ def connect() -> DatabricksSQLConnectionWrapper: _user_agent_entry=user_agent_entry, **connection_parameters, ) - logger.debug(ConnectionCreated(str(conn))) + + if conn: + databricks_connection.session_id = conn.get_session_id_hex() + databricks_connection.last_used_time = time.time() + logger.debug(ConnectionCreated(str(databricks_connection))) return DatabricksSQLConnectionWrapper( conn, @@ -693,58 +700,74 @@ def exponential_backoff(attempt: int) -> int: ) @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: - _query_id = getattr(cursor, "hex_query_id", None) - if cursor is None: - logger.debug("No cursor was provided. Query ID not available.") - query_id = "N/A" - else: - query_id = _query_id - message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore + def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: + if connection.state == ConnectionState.OPEN: + return connection + creds: DatabricksCredentials = connection.credentials + timeout = creds.connect_timeout -class ExtendedSessionConnectionManager(DatabricksConnectionManager): - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None: - assert ( - GlobalState.get_use_long_sessions() - ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" - super().__init__(profile, mp_context) - self.threads_compute_connections: dict[ - Hashable, dict[Hashable, DatabricksDBTConnection] - ] = {} + # gotta keep this so we don't prompt users many times + cls.credentials_provider = creds.authenticate(cls.credentials_provider) - def set_connection_name( - self, name: Optional[str] = None, query_header_context: Any = None - ) -> Connection: - """Called by 'acquire_connection' in DatabricksAdapter, which is called by - 'connection_named', called by 'connection_for(node)'. - Creates a connection for this thread if one doesn't already - exist, and will rename an existing connection.""" - self._cleanup_idle_connections() + invocation_env = creds.get_invocation_env() + user_agent_entry = cls._user_agent + if invocation_env: + user_agent_entry = f"{cls._user_agent}; {invocation_env}" - conn_name: str = "master" if name is None else name + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - # Get a connection for this thread - conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") + http_headers: list[tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) - if conn is None: - conn = self._create_compute_connection(conn_name, query_header_context) - else: # existing connection either wasn't open or didn't have the right name - conn = self._update_compute_connection(conn, conn_name) + # If a model specifies a compute resource the http path + # may be different than the http_path property of creds. 
+ http_path = _get_http_path(query_header_context, creds) - conn._acquire(query_header_context) + def connect() -> DatabricksSQLConnectionWrapper: + try: + # TODO: what is the error when a user specifies a catalog they don't have access to + conn: DatabricksSQLConnection = dbsql.connect( + server_hostname=creds.host, + http_path=http_path, + credentials_provider=cls.credentials_provider, + http_headers=http_headers if http_headers else None, + session_configuration=creds.session_properties, + catalog=creds.database, + use_inline_params="silent", + # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. + _user_agent_entry=user_agent_entry, + **connection_parameters, + ) + logger.debug(ConnectionCreated(str(conn))) - return conn + return DatabricksSQLConnectionWrapper( + conn, + is_cluster=creds.cluster_id is not None, + creds=creds, + user_agent=user_agent_entry, + ) + except Error as exc: + logger.error(ConnectionCreateError(exc)) + raise - # override - def release(self) -> None: - with self.lock: - conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if conn is None: - return + def exponential_backoff(attempt: int) -> int: + return attempt * attempt - conn._release() + retryable_exceptions = [] + # this option is for backwards compatibility + if creds.retry_all: + retryable_exceptions = [Error] + + return cls.retry_connection( + connection, + connect=connect, + logger=logger, + retryable_exceptions=retryable_exceptions, + retry_limit=creds.connect_retries, + retry_timeout=(timeout if timeout is not None else exponential_backoff), + ) # override @classmethod @@ -756,46 +779,22 @@ def close(cls, connection: Connection) -> Connection: connection.state = ConnectionState.CLOSED return connection - # override - def cleanup_all(self) -> None: - with self.lock: - for thread_connections in self.threads_compute_connections.values(): - for connection in thread_connections.values(): - if connection.acquire_release_count > 0: - fire_event( - ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) - ) - else: - fire_event( - ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) - ) - self.close(connection) - - # garbage collect these connections - self.thread_connections.clear() - self.threads_compute_connections.clear() - - def _update_compute_connection( - self, conn: DatabricksDBTConnection, new_name: str - ) -> DatabricksDBTConnection: - if conn.name == new_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn - - orig_conn_name: str = conn.name or "" - - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.open) - if conn.name != new_name: - conn.name = new_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - - current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: - self.clear_thread_connection() - self.set_thread_connection(conn) + @classmethod + def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: + _query_id = getattr(cursor, "hex_query_id", None) + if cursor is None: + logger.debug("No cursor was provided. 
Query ID not available.") + query_id = "N/A" + else: + query_id = _query_id + message = "OK" + return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore - logger.debug(ConnectionReuse(str(conn), orig_conn_name)) + def get_thread_connection(self) -> Connection: + conn = super().get_thread_connection() + self._cleanup_idle_connections() + dbr_conn = cast(DatabricksDBTConnection, conn) + logger.debug(ConnectionRetrieve(str(dbr_conn))) return conn @@ -810,28 +809,6 @@ def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: ) thread_map[conn.compute_name] = conn - def _get_compute_connections( - self, - ) -> dict[Hashable, DatabricksDBTConnection]: - """Retrieve a map of compute name to connection for the current thread.""" - - thread_id = self.get_thread_identifier() - with self.lock: - thread_map = self.threads_compute_connections.get(thread_id) - if not thread_map: - thread_map = {} - self.threads_compute_connections[thread_id] = thread_map - return thread_map - - def _get_if_exists_compute_connection( - self, compute_name: str - ) -> Optional[DatabricksDBTConnection]: - """Get the connection for the current thread and named compute, if it exists.""" - - with self.lock: - threads_map = self._get_compute_connections() - return threads_map.get(compute_name) - def _cleanup_idle_connections(self) -> None: with self.lock: # Get all connections associated with this thread. There can be multiple connections @@ -897,95 +874,51 @@ def _create_compute_connection( return conn - def get_thread_connection(self) -> Connection: - conn = super().get_thread_connection() - self._cleanup_idle_connections() - dbr_conn = cast(DatabricksDBTConnection, conn) - logger.debug(ConnectionRetrieve(str(dbr_conn))) - - return conn - - @classmethod - def open(cls, connection: Connection) -> Connection: - # Once long session management is no longer under the USE_LONG_SESSIONS toggle - # this should be renamed and replace the _open class method. - assert ( - GlobalState.get_use_long_sessions() - ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" - - databricks_connection = cast(DatabricksDBTConnection, connection) - - if connection.state == ConnectionState.OPEN: - return connection - - creds: DatabricksCredentials = connection.credentials - timeout = creds.connect_timeout - - # gotta keep this so we don't prompt users many times - cls.credentials_provider = creds.authenticate(cls.credentials_provider) - - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + with self.lock: + threads_map = self._get_compute_connections() + return threads_map.get(compute_name) - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() - ) + def _get_compute_connections( + self, + ) -> dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. 
- http_path = databricks_connection.http_path + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map - def connect() -> DatabricksSQLConnectionWrapper: - try: - # TODO: what is the error when a user specifies a catalog they don't have access to - conn = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. - _user_agent_entry=user_agent_entry, - **connection_parameters, - ) + def _update_compute_connection( + self, conn: DatabricksDBTConnection, new_name: str + ) -> DatabricksDBTConnection: + if conn.name == new_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn - if conn: - databricks_connection.session_id = conn.get_session_id_hex() - databricks_connection.last_used_time = time.time() - logger.debug(ConnectionCreated(str(databricks_connection))) + orig_conn_name: str = conn.name or "" - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) - except Error as exc: - logger.error(ConnectionCreateError(exc)) - raise + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self.open) + if conn.name != new_name: + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - def exponential_backoff(attempt: int) -> int: - return attempt * attempt + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) - retryable_exceptions = [] - # this option is for backwards compatibility - if creds.retry_all: - retryable_exceptions = [Error] + logger.debug(ConnectionReuse(str(conn), orig_conn_name)) - return cls.retry_connection( - connection, - connect=connect, - logger=logger, - retryable_exceptions=retryable_exceptions, - retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), - ) + return conn def _get_compute_name(query_header_context: Any) -> Optional[str]: @@ -1005,24 +938,18 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O """Get the http_path for the compute specified for the node. If none is specified default will be used.""" - thread_id = (os.getpid(), get_ident()) - # ResultNode *should* have relation_name attr, but we work around a core # issue by checking. relation_name = getattr(query_header_context, "relation_name", "[unknown]") # If there is no node we return the http_path for the default compute. if not query_header_context: - if not GlobalState.get_use_long_sessions(): - logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path # Get the name of the compute resource specified in the node's config. # If none is specified return the http_path for the default compute. 
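     # (Illustration, based on the adapter docs rather than this diff: a model selects a
     # named compute with a config such as `databricks_compute: my_serverless_warehouse`,
     # and that name must match an entry under `compute:` in the profile that provides an
     # `http_path`; otherwise the default `http_path` from the credentials is used.)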
compute_name = _get_compute_name(query_header_context) if not compute_name: - if not GlobalState.get_use_long_sessions(): - logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") return creds.http_path # Get the http_path for the named compute. @@ -1037,11 +964,6 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O f"does not specify http_path, relation: {relation_name}" ) - if not GlobalState.get_use_long_sessions(): - logger.debug( - f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." - ) - return http_path diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py index de240d39..cdc5df98 100644 --- a/dbt/adapters/databricks/global_state.py +++ b/dbt/adapters/databricks/global_state.py @@ -7,16 +7,6 @@ class GlobalState: single place than scattered throughout the codebase. """ - __use_long_sessions: ClassVar[Optional[bool]] = None - - @classmethod - def get_use_long_sessions(cls) -> bool: - if cls.__use_long_sessions is None: - cls.__use_long_sessions = ( - os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" - ) - return cls.__use_long_sessions - __invocation_env: ClassVar[Optional[str]] = None __invocation_env_set: ClassVar[bool] = False diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 15c333e2..d106dd1c 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -31,10 +31,7 @@ GetColumnsByInformationSchema, ) from dbt.adapters.databricks.column import DatabricksColumn -from dbt.adapters.databricks.connections import ( - DatabricksConnectionManager, - ExtendedSessionConnectionManager, -) +from dbt.adapters.databricks.connections import DatabricksConnectionManager from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, @@ -154,10 +151,7 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - if GlobalState.get_use_long_sessions(): - ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager - else: - ConnectionManager = DatabricksConnectionManager + ConnectionManager = DatabricksConnectionManager connections: DatabricksConnectionManager From 407cf240285e1212dbe697d84e97437fe430f77f Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 28 Jan 2025 08:38:09 -0800 Subject: [PATCH 10/21] Attempt to fix git paths (#920) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d2f728d8..c176f1ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,10 +65,10 @@ check-sdist = [ [tool.hatch.envs.default] dependencies = [ "dbt_common @ git+https://github.com/dbt-labs/dbt-common.git", - "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git@main", + "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-adapters", "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git@main#subdirectory=core", "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter", - "dbt-spark @ git+https://github.com/dbt-labs/dbt-spark.git@main", + "dbt-spark @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-spark", "pytest", "pytest-xdist", "pytest-dotenv", From ebadd924ee9e0856c2397562a044ce0d8bd4e240 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Gerard=20Sol=C3=A0?= Date: Thu, 30 Jan 2025 00:09:06 +0100 Subject: [PATCH 11/21] Update impl.py to use POSIX standard on location (#919) --- CHANGELOG.md | 1 + dbt/adapters/databricks/impl.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e72506e..1a84cde1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) +- Use POSIX standard when creating location for the tables ([919](https://github.com/databricks/dbt-databricks/pull/919)) ## dbt-databricks 1.9.2 (Jan 21, 2024) diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index d106dd1c..8adef638 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -1,4 +1,4 @@ -import os +import posixpath import re from abc import ABC, abstractmethod from collections import defaultdict @@ -214,9 +214,9 @@ def compute_external_path( raise DbtConfigError("location_root is required for external tables.") include_full_name_in_path = config.get("include_full_name_in_path", False) if include_full_name_in_path: - path = os.path.join(location_root, database, schema, identifier) + path = posixpath.join(location_root, database, schema, identifier) else: - path = os.path.join(location_root, identifier) + path = posixpath.join(location_root, identifier) if is_incremental: path = path + "_tmp" return path From ccb1fd74c2b0480ccfc75983570dc0bd608bf036 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:08:35 -0800 Subject: [PATCH 12/21] Pin SQL Connector to < 3.7 (#923) --- CHANGELOG.md | 8 +++++++- pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a84cde1..6113a08c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,16 @@ -## dbt-databricks 1.9.3 (TBD) +## dbt-databricks 1.9.4 (TBD) ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) - Use POSIX standard when creating location for the tables ([919](https://github.com/databricks/dbt-databricks/pull/919)) +## dbt-databricks 1.9.3 (Jan 30, 2024) + +### Under the Hood + +- Pinned the python sql connector to 3.6.0 as a temporary measure while we investigate failure to wait for cluster start + ## dbt-databricks 1.9.2 (Jan 21, 2024) ### Features diff --git a/pyproject.toml b/pyproject.toml index c176f1ed..7c8c0e82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "databricks-sdk==0.17.0", - "databricks-sql-connector>=3.5.0, <4.0.0", + "databricks-sql-connector>=3.5.0, <3.7.0", "dbt-adapters>=1.7.0, <2.0", "dbt-common>=1.10.0, <2.0", "dbt-core>=1.8.7, <2.0", From fd9c6587bb6122e5b3614569bca2d499f531ace7 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Fri, 31 Jan 2025 09:25:54 -0800 Subject: [PATCH 13/21] Fixing Changelog (#924) --- CHANGELOG.md | 8 ++++++-- dbt/adapters/databricks/__version__.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6113a08c..dd6088f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,16 +1,20 @@ -## dbt-databricks 1.9.4 (TBD) +## dbt-databricks 1.9.5 (TBD) ### Under the Hood - Collapsing to a single 
connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) - Use POSIX standard when creating location for the tables ([919](https://github.com/databricks/dbt-databricks/pull/919)) -## dbt-databricks 1.9.3 (Jan 30, 2024) +## dbt-databricks 1.9.4 (Jan 30, 2024) ### Under the Hood - Pinned the python sql connector to 3.6.0 as a temporary measure while we investigate failure to wait for cluster start +## dbt-databricks 1.9.3 + +Yanked due to being published with the incorrect bits + ## dbt-databricks 1.9.2 (Jan 21, 2024) ### Features diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py index 1b022739..53988968 100644 --- a/dbt/adapters/databricks/__version__.py +++ b/dbt/adapters/databricks/__version__.py @@ -1 +1 @@ -version = "1.9.2" +version = "1.9.4" From 75a71dd10fd1ed6b723cd6ca87720d6c0f035d24 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Wed, 5 Feb 2025 11:53:33 -0800 Subject: [PATCH 14/21] Clean up cursor management (#912) --- CHANGELOG.md | 1 + dbt/adapters/databricks/connections.py | 407 ++---------------- .../databricks/events/connection_events.py | 24 -- .../databricks/events/cursor_events.py | 59 --- dbt/adapters/databricks/global_state.py | 4 + dbt/adapters/databricks/handle.py | 324 ++++++++++++++ dbt/adapters/databricks/impl.py | 7 +- dbt/adapters/databricks/utils.py | 13 +- tests/unit/events/test_connection_events.py | 11 - tests/unit/events/test_cursor_events.py | 41 -- tests/unit/test_adapter.py | 18 +- tests/unit/test_handle.py | 211 +++++++++ tests/unit/test_idle_config.py | 6 +- 13 files changed, 613 insertions(+), 513 deletions(-) delete mode 100644 dbt/adapters/databricks/events/cursor_events.py create mode 100644 dbt/adapters/databricks/handle.py delete mode 100644 tests/unit/events/test_cursor_events.py create mode 100644 tests/unit/test_handle.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dd6088f4..5ed8ce95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) - Use POSIX standard when creating location for the tables ([919](https://github.com/databricks/dbt-databricks/pull/919)) +- Clean up cursor management in the hopes of limiting issues with cancellation ([912](https://github.com/databricks/dbt-databricks/pull/912)) ## dbt-databricks 1.9.4 (Jan 30, 2024) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 3a6b4817..b536fda6 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1,14 +1,9 @@ -import decimal import re -import sys import time -import uuid -import warnings -from collections.abc import Callable, Hashable, Iterator, Sequence +from collections.abc import Callable, Hashable, Iterator from contextlib import contextmanager from dataclasses import dataclass from multiprocessing.context import SpawnContext -from numbers import Number from typing import TYPE_CHECKING, Any, Optional, cast from dbt_common.events.contextvars import get_node_info @@ -16,9 +11,7 @@ from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError from dbt_common.utils import cast_to_str -import databricks.sql as dbsql -from databricks.sql.client import Connection as DatabricksSQLConnection -from databricks.sql.client import Cursor as DatabricksSQLCursor +from databricks.sql 
import __version__ as dbsql_version from databricks.sql.exc import Error from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.adapters.contracts.connection import ( @@ -32,15 +25,13 @@ ) from dbt.adapters.databricks.__version__ import version as __version__ from dbt.adapters.databricks.api_client import DatabricksApiClient -from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider +from dbt.adapters.databricks.credentials import ( + DatabricksCredentials, + TCredentialProvider, +) from dbt.adapters.databricks.events.connection_events import ( ConnectionAcquire, - ConnectionCancel, - ConnectionCancelError, - ConnectionClose, - ConnectionCloseError, ConnectionCreate, - ConnectionCreated, ConnectionCreateError, ConnectionIdleCheck, ConnectionIdleClose, @@ -49,14 +40,8 @@ ConnectionRetrieve, ConnectionReuse, ) -from dbt.adapters.databricks.events.cursor_events import ( - CursorCancel, - CursorCancelError, - CursorClose, - CursorCloseError, - CursorCreate, -) from dbt.adapters.databricks.events.other_events import QueryError +from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -81,194 +66,19 @@ ) -DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") - - # Number of idle seconds before a connection is automatically closed. Only applicable if # USE_LONG_SESSIONS is true. # Updated when idle times of 180s were causing errors DEFAULT_MAX_IDLE_TIME = 60 -class DatabricksSQLConnectionWrapper: - """Wrap a Databricks SQL connector in a way that no-ops transactions""" - - _conn: DatabricksSQLConnection - _is_cluster: bool - _cursors: list[DatabricksSQLCursor] - _creds: DatabricksCredentials - _user_agent: str - - def __init__( - self, - conn: DatabricksSQLConnection, - *, - is_cluster: bool, - creds: DatabricksCredentials, - user_agent: str, - ): - self._conn = conn - self._is_cluster = is_cluster - self._cursors = [] - self._creds = creds - self._user_agent = user_agent - - def cursor(self) -> "DatabricksSQLCursorWrapper": - cursor = self._conn.cursor() - - logger.debug(CursorCreate(cursor)) - - self._cursors.append(cursor) - return DatabricksSQLCursorWrapper( - cursor, - creds=self._creds, - user_agent=self._user_agent, - ) - - def cancel(self) -> None: - logger.debug(ConnectionCancel(self._conn)) - - cursors: list[DatabricksSQLCursor] = self._cursors - - for cursor in cursors: - try: - cursor.cancel() - except Error as exc: - logger.warning(ConnectionCancelError(self._conn, exc)) - - def close(self) -> None: - logger.debug(ConnectionClose(self._conn)) - - try: - self._conn.close() - except Error as exc: - logger.warning(ConnectionCloseError(self._conn, exc)) - - def rollback(self, *args: Any, **kwargs: Any) -> None: - logger.debug("NotImplemented: rollback") - - _dbr_version: tuple[int, int] - - @property - def dbr_version(self) -> tuple[int, int]: - if not hasattr(self, "_dbr_version"): - if self._is_cluster: - with self._conn.cursor() as cursor: - cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion") - results = cursor.fetchone() - if results: - dbr_version: str = results[1] - - m = DBR_VERSION_REGEX.search(dbr_version) - assert m, f"Unknown DBR version: {dbr_version}" - major = int(m.group(1)) - try: - minor = int(m.group(2)) - except ValueError: - minor = sys.maxsize - self._dbr_version = 
(major, minor) - else: - # Assuming SQL Warehouse uses the latest version. - self._dbr_version = (sys.maxsize, sys.maxsize) - - return self._dbr_version - - -class DatabricksSQLCursorWrapper: - """Wrap a Databricks SQL cursor in a way that no-ops transactions""" - - _cursor: DatabricksSQLCursor - _user_agent: str - _creds: DatabricksCredentials - - def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str): - self._cursor = cursor - self._creds = creds - self._user_agent = user_agent - - def cancel(self) -> None: - logger.debug(CursorCancel(self._cursor)) - - try: - self._cursor.cancel() - except Error as exc: - logger.warning(CursorCancelError(self._cursor, exc)) - - def close(self) -> None: - logger.debug(CursorClose(self._cursor)) - - try: - self._cursor.close() - except Error as exc: - logger.warning(CursorCloseError(self._cursor, exc)) - - def fetchall(self) -> Sequence[tuple]: - return self._cursor.fetchall() - - def fetchone(self) -> Optional[tuple]: - return self._cursor.fetchone() - - def fetchmany(self, size: int) -> Sequence[tuple]: - return self._cursor.fetchmany(size) - - def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: - # print(f"execute: {sql}") - if sql.strip().endswith(";"): - sql = sql.strip()[:-1] - if bindings is not None: - bindings = [self._fix_binding(binding) for binding in bindings] - self._cursor.execute(sql, bindings) - - @property - def hex_query_id(self) -> str: - """Return the hex GUID for this query - - This UUID can be tied back to the Databricks query history API - """ - if self._cursor.active_result_set: - _as_hex = uuid.UUID(bytes=self._cursor.active_result_set.command_id.operationId.guid) - return str(_as_hex) - return "" - - @classmethod - def _fix_binding(cls, value: Any) -> Any: - """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - - @property - def description(self) -> Optional[list[tuple]]: - return self._cursor.description - - def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None: - self._cursor.schemas(catalog_name=catalog_name, schema_name=schema_name) - - def tables(self, catalog_name: str, schema_name: str, table_name: Optional[str] = None) -> None: - self._cursor.tables( - catalog_name=catalog_name, schema_name=schema_name, table_name=table_name - ) - - def __del__(self) -> None: - if self._cursor.open: - # This should not happen. The cursor should explicitly be closed. 
- logger.debug(CursorClose(self._cursor)) - - self._cursor.close() - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn("The cursor was closed by destructor.") - - DATABRICKS_QUERY_COMMENT = f""" {{%- set comment_dict = {{}} -%}} {{%- do comment_dict.update( app='dbt', dbt_version=dbt_version, dbt_databricks_version='{__version__}', - databricks_sql_connector_version='{dbsql.__version__}', + databricks_sql_connector_version='{dbsql_version}', profile_name=target.get('profile_name'), target_name=target.get('target_name'), ) -%}} @@ -292,11 +102,6 @@ def _get_comment_macro(self) -> Optional[str]: return self.config.query_comment.comment -@dataclass -class DatabricksAdapterResponse(AdapterResponse): - query_id: str = "" - - @dataclass(init=False) class DatabricksDBTConnection(Connection): last_used_time: Optional[float] = None @@ -374,7 +179,6 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None: class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_provider: Optional[TCredentialProvider] = None - _user_agent = f"dbt-databricks/{__version__}" def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) @@ -393,8 +197,8 @@ def cancel_open(self) -> list[str]: def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) - connection: DatabricksSQLConnectionWrapper = self.get_thread_connection().handle - dbr_version = connection.dbr_version + handle: DatabricksHandle = self.get_thread_connection().handle + dbr_version = handle.dbr_version return (dbr_version > version) - (dbr_version < version) def set_query_header(self, query_header_context: dict[str, Any]) -> None: @@ -464,7 +268,7 @@ def add_query( fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(sql): - cursor: Optional[DatabricksSQLCursorWrapper] = None + cursor: Optional[CursorWrapper] = None try: log_sql = redact_credentials(sql) if abridge_sql_log: @@ -480,12 +284,12 @@ def add_query( pre = time.time() - cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor() - cursor.execute(sql, bindings) + handle: DatabricksHandle = connection.handle + cursor = handle.execute(sql, bindings) fire_event( SQLQueryStatus( - status=str(self.get_response(cursor)), + status=str(cursor.get_response()), elapsed=round((time.time() - pre), 2), node_info=get_node_info(), ) @@ -493,9 +297,7 @@ def add_query( return connection, cursor except Error: - if cursor is not None: - cursor.close() - cursor = None + close_cursor = True raise finally: if close_cursor and cursor is not None: @@ -507,11 +309,11 @@ def execute( auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> tuple[DatabricksAdapterResponse, "Table"]: + ) -> tuple[AdapterResponse, "Table"]: sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) try: - response = self.get_response(cursor) + response = cursor.get_response() if fetch: table = self.get_result_from_cursor(cursor, limit) else: @@ -523,15 +325,15 @@ def execute( finally: cursor.close() - def _execute_cursor( - self, log_sql: str, f: Callable[[DatabricksSQLCursorWrapper], None] + def _execute_with_cursor( + self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper] ) -> "Table": connection = self.get_thread_connection() fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with 
self.exception_handler(log_sql): - cursor: Optional[DatabricksSQLCursorWrapper] = None + cursor: Optional[CursorWrapper] = None try: fire_event( SQLQuery( @@ -543,9 +345,8 @@ def _execute_cursor( pre = time.time() - handle: DatabricksSQLConnectionWrapper = connection.handle - cursor = handle.cursor() - f(cursor) + handle: DatabricksHandle = connection.handle + cursor = f(handle) fire_event( SQLQueryStatus( @@ -557,28 +358,24 @@ def _execute_cursor( return self.get_result_from_cursor(cursor, None) finally: - if cursor is not None: + if cursor: cursor.close() def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table": database = database.strip("`") if schema: schema = schema.strip("`").lower() - return self._execute_cursor( + return self._execute_with_cursor( f"GetSchemas(database={database}, schema={schema})", - lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema), + lambda cursor: cursor.list_schemas(database=database, schema=schema), ) - def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> "Table": + def list_tables(self, database: str, schema: str) -> "Table": database = database.strip("`") schema = schema.strip("`").lower() - if identifier: - identifier = identifier.strip("`") - return self._execute_cursor( - f"GetTables(database={database}, schema={schema}, identifier={identifier})", - lambda cursor: cursor.tables( - catalog_name=database, schema_name=schema, table_name=identifier - ), + return self._execute_with_cursor( + f"GetTables(database={database}, schema={schema})", + lambda cursor: cursor.list_tables(database=database, schema=schema), ) # override @@ -609,20 +406,6 @@ def cleanup_all(self) -> None: self.thread_connections.clear() self.threads_compute_connections.clear() - @classmethod - def get_open_for_context( - cls, query_header_context: Any = None - ) -> Callable[[Connection], Connection]: - # If there is no node we can simply return the exsting class method open. - # If there is a node create a closure that will call cls._open with the node. - if not query_header_context: - return cls.open - - def open_for_model(connection: Connection) -> Connection: - return cls._open(connection, query_header_context) - - return open_for_model - @classmethod def open(cls, connection: Connection) -> Connection: databricks_connection = cast(DatabricksDBTConnection, connection) @@ -635,119 +418,23 @@ def open(cls, connection: Connection) -> Connection: # gotta keep this so we don't prompt users many times cls.credentials_provider = creds.authenticate(cls.credentials_provider) - - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" - - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + conn_args = SqlUtils.prepare_connection_arguments( + creds, cls.credentials_provider, databricks_connection.http_path ) - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. 
- http_path = databricks_connection.http_path - - def connect() -> DatabricksSQLConnectionWrapper: + def connect() -> DatabricksHandle: try: # TODO: what is the error when a user specifies a catalog they don't have access to - conn = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. - _user_agent_entry=user_agent_entry, - **connection_parameters, + conn = DatabricksHandle.from_connection_args( + conn_args, creds.cluster_id is not None ) - if conn: - databricks_connection.session_id = conn.get_session_id_hex() - databricks_connection.last_used_time = time.time() - logger.debug(ConnectionCreated(str(databricks_connection))) - - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) - except Error as exc: - logger.error(ConnectionCreateError(exc)) - raise - - def exponential_backoff(attempt: int) -> int: - return attempt * attempt - - retryable_exceptions = [] - # this option is for backwards compatibility - if creds.retry_all: - retryable_exceptions = [Error] - - return cls.retry_connection( - connection, - connect=connect, - logger=logger, - retryable_exceptions=retryable_exceptions, - retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), - ) - - @classmethod - def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: - if connection.state == ConnectionState.OPEN: - return connection - - creds: DatabricksCredentials = connection.credentials - timeout = creds.connect_timeout + databricks_connection.session_id = conn.session_id + databricks_connection.last_used_time = time.time() - # gotta keep this so we don't prompt users many times - cls.credentials_provider = creds.authenticate(cls.credentials_provider) - - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" - - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() - ) - - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. - http_path = _get_http_path(query_header_context, creds) - - def connect() -> DatabricksSQLConnectionWrapper: - try: - # TODO: what is the error when a user specifies a catalog they don't have access to - conn: DatabricksSQLConnection = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. 
- _user_agent_entry=user_agent_entry, - **connection_parameters, - ) - logger.debug(ConnectionCreated(str(conn))) - - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) + return conn + else: + raise DbtDatabaseError("Failed to create connection") except Error as exc: logger.error(ConnectionCreateError(exc)) raise @@ -780,15 +467,11 @@ def close(cls, connection: Connection) -> Connection: return connection @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: - _query_id = getattr(cursor, "hex_query_id", None) - if cursor is None: - logger.debug("No cursor was provided. Query ID not available.") - query_id = "N/A" + def get_response(cls, cursor: Any) -> AdapterResponse: + if isinstance(cursor, CursorWrapper): + return cursor.get_response() else: - query_id = _query_id - message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore + return AdapterResponse("OK") def get_thread_connection(self) -> Connection: conn = super().get_thread_connection() @@ -832,7 +515,7 @@ def _cleanup_idle_connections(self) -> None: ) and conn._idle_too_long(): logger.debug(ConnectionIdleClose(str(conn))) self.close(conn) - conn._reset_handle(self._open) + conn._reset_handle(self.open) def _create_compute_connection( self, conn_name: str, query_header_context: Any = None @@ -982,7 +665,7 @@ def _get_max_idle_time(query_header_context: Any, creds: DatabricksCredentials) "connect_max_idle", max_idle_time ) - if not isinstance(max_idle_time, Number): + if not isinstance(max_idle_time, int): if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): return int(max_idle_time.strip()) else: diff --git a/dbt/adapters/databricks/events/connection_events.py b/dbt/adapters/databricks/events/connection_events.py index 9f8ec8c1..3f533c1d 100644 --- a/dbt/adapters/databricks/events/connection_events.py +++ b/dbt/adapters/databricks/events/connection_events.py @@ -17,30 +17,6 @@ def __str__(self) -> str: return f"Connection(session-id={self.session_id}) - {self.message}" -class ConnectionCancel(ConnectionEvent): - def __init__(self, connection: Optional[Connection]): - super().__init__(connection, "Cancelling connection") - - -class ConnectionClose(ConnectionEvent): - def __init__(self, connection: Optional[Connection]): - super().__init__(connection, "Closing connection") - - -class ConnectionCancelError(ConnectionEvent): - def __init__(self, connection: Optional[Connection], exception: Exception): - super().__init__( - connection, str(SQLErrorEvent(exception, "Exception while trying to cancel connection")) - ) - - -class ConnectionCloseError(ConnectionEvent): - def __init__(self, connection: Optional[Connection], exception: Exception): - super().__init__( - connection, str(SQLErrorEvent(exception, "Exception while trying to close connection")) - ) - - class ConnectionCreateError(ConnectionEvent): def __init__(self, exception: Exception): super().__init__( diff --git a/dbt/adapters/databricks/events/cursor_events.py b/dbt/adapters/databricks/events/cursor_events.py deleted file mode 100644 index d94a002a..00000000 --- a/dbt/adapters/databricks/events/cursor_events.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABC -from uuid import UUID - -from databricks.sql.client import Cursor - -from dbt.adapters.databricks.events.base import SQLErrorEvent - - -class CursorEvent(ABC): - def __init__(self, cursor: Cursor, message: str): - 
self.message = message - self.session_id = "Unknown" - self.command_id = "Unknown" - if cursor: - if cursor.connection: - self.session_id = cursor.connection.get_session_id_hex() - if ( - cursor.active_result_set - and cursor.active_result_set.command_id - and cursor.active_result_set.command_id.operationId - ): - self.command_id = ( - str(UUID(bytes=cursor.active_result_set.command_id.operationId.guid)) - or "Unknown" - ) - - def __str__(self) -> str: - return ( - f"Cursor(session-id={self.session_id}, command-id={self.command_id}) - {self.message}" - ) - - -class CursorCloseError(CursorEvent): - def __init__(self, cursor: Cursor, exception: Exception): - super().__init__( - cursor, str(SQLErrorEvent(exception, "Exception while trying to close cursor")) - ) - - -class CursorCancelError(CursorEvent): - def __init__(self, cursor: Cursor, exception: Exception): - super().__init__( - cursor, str(SQLErrorEvent(exception, "Exception while trying to cancel cursor")) - ) - - -class CursorCreate(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Created cursor") - - -class CursorClose(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Closing cursor") - - -class CursorCancel(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Cancelling cursor") diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py index cdc5df98..e01bb0ce 100644 --- a/dbt/adapters/databricks/global_state.py +++ b/dbt/adapters/databricks/global_state.py @@ -1,6 +1,8 @@ import os from typing import ClassVar, Optional +from dbt.adapters.databricks.__version__ import version as __version__ + class GlobalState: """Global state is a bad idea, but since we don't control instantiation, better to have it in a @@ -10,6 +12,8 @@ class GlobalState: __invocation_env: ClassVar[Optional[str]] = None __invocation_env_set: ClassVar[bool] = False + USER_AGENT = f"dbt-databricks/{__version__}" + @classmethod def get_invocation_env(cls) -> Optional[str]: if not cls.__invocation_env_set: diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py new file mode 100644 index 00000000..1b7d30db --- /dev/null +++ b/dbt/adapters/databricks/handle.py @@ -0,0 +1,324 @@ +import decimal +import re +import sys +from collections.abc import Callable, Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +from dbt_common.exceptions import DbtRuntimeError + +import databricks.sql as dbsql +from databricks.sql.client import Connection, Cursor +from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.databricks import utils +from dbt.adapters.databricks.__version__ import version as __version__ +from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider +from dbt.adapters.databricks.logging import logger + +if TYPE_CHECKING: + pass + + +CursorOp = Callable[[Cursor], None] +CursorExecOp = Callable[[Cursor], Cursor] +CursorWrapperOp = Callable[["CursorWrapper"], None] +ConnectionOp = Callable[[Optional[Connection]], None] +LogOp = Callable[[], str] +FailLogOp = Callable[[Exception], str] + + +class CursorWrapper: + """ + Wrap the DBSQL cursor to abstract the details from DatabricksConnectionManager. 
+ """ + + def __init__(self, cursor: Cursor): + self._cursor = cursor + self.open = True + + @property + def description(self) -> Optional[list[tuple]]: + return self._cursor.description + + def cancel(self) -> None: + if self._cursor.active_op_handle: + self._cleanup( + lambda cursor: cursor.cancel(), + lambda: f"{self} - Cancelling", + lambda ex: f"{self} - Exception while cancelling: {ex}", + ) + + def close(self) -> None: + self._cleanup( + lambda cursor: cursor.close(), + lambda: f"{self} - Closing", + lambda ex: f"{self} - Exception while closing: {ex}", + ) + + def _cleanup( + self, + cleanup: CursorOp, + startLog: LogOp, + failLog: FailLogOp, + ) -> None: + """ + Common cleanup function for cursor operations, handling either close or cancel. + """ + if self.open: + self.open = False + logger.debug(startLog()) + utils.handle_exceptions_as_warning(lambda: cleanup(self._cursor), failLog) + + def fetchall(self) -> Sequence[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchall()) + + def fetchone(self) -> Optional[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchone()) + + def fetchmany(self, size: int) -> Sequence[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchmany(size)) + + def get_response(self) -> AdapterResponse: + return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A") + + T = TypeVar("T") + + def _safe_execute(self, f: Callable[[Cursor], T]) -> T: + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed cursor") + return f(self._cursor) + + def __str__(self) -> str: + return ( + f"Cursor(session-id={self._cursor.connection.get_session_id_hex()}, " + f"command-id={self._cursor.query_id})" + ) + + def __enter__(self) -> "CursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + + +class DatabricksHandle: + """ + Handle for a Databricks SQL Session. + Provides a layer of abstraction over the Databricks SQL client library such that + DatabricksConnectionManager does not depend on the details of this library directly. + """ + + def __init__( + self, + conn: Connection, + is_cluster: bool, + ): + self._conn = conn + self.open = True + self._cursor: Optional[CursorWrapper] = None + self._dbr_version: Optional[tuple[int, int]] = None + self._is_cluster = is_cluster + + @property + def dbr_version(self) -> tuple[int, int]: + """ + Gets the DBR version of the current session. + """ + if not self._dbr_version: + if self._is_cluster: + cursor = self._safe_execute( + lambda cursor: cursor.execute( + "SET spark.databricks.clusterUsageTags.sparkVersion" + ) + ) + results = cursor.fetchone() + self._dbr_version = SqlUtils.extract_dbr_version(results[1] if results else "") + cursor.close() + else: + # Assuming SQL Warehouse uses the latest version. + self._dbr_version = (sys.maxsize, sys.maxsize) + + return self._dbr_version + + @property + def session_id(self) -> str: + return self._conn.get_session_id_hex() + + def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> CursorWrapper: + """ + Execute a SQL statement on the current session with optional bindings. 
+ """ + return self._safe_execute( + lambda cursor: cursor.execute( + SqlUtils.clean_sql(sql), SqlUtils.translate_bindings(bindings) + ) + ) + + def list_schemas(self, database: str, schema: Optional[str] = None) -> CursorWrapper: + """ + Get a cursor for listing schemas in the given database. + """ + return self._safe_execute(lambda cursor: cursor.schemas(database, schema)) + + def list_tables(self, database: str, schema: str) -> "CursorWrapper": + """ + Get a cursor for listing tables in the given database and schema. + """ + + return self._safe_execute(lambda cursor: cursor.tables(database, schema)) + + def cancel(self) -> None: + """ + Cancel in progress query, if any, then close connection and cursor. + """ + self._cleanup( + lambda cursor: cursor.cancel(), + lambda: f"{self} - Cancelling", + lambda ex: f"{self} - Exception while cancelling: {ex}", + ) + + def close(self) -> None: + """ + Close the connection and cursor. + """ + + self._cleanup( + lambda cursor: cursor.close(), + lambda: f"{self} - Closing", + lambda ex: f"{self} - Exception while closing: {ex}", + ) + + def rollback(self) -> None: + """ + Required for interface compatibility, but not implemented. + """ + logger.debug("NotImplemented: rollback") + + @staticmethod + def from_connection_args( + conn_args: dict[str, Any], is_cluster: bool + ) -> Optional["DatabricksHandle"]: + """ + Create a new DatabricksHandle from the given connection arguments. + """ + + conn = dbsql.connect(**conn_args) + if not conn: + logger.warning(f"Failed to create connection for {conn_args.get('http_path')}") + return None + connection = DatabricksHandle(conn, is_cluster=is_cluster) + logger.debug(f"{connection} - Created") + + return connection + + def _cleanup( + self, + cursor_op: CursorWrapperOp, + startLog: LogOp, + failLog: FailLogOp, + ) -> None: + """ + Function for cleaning up the connection and cursor, handling either close or cancel. + """ + if self.open: + self.open = False + logger.debug(startLog()) + + if self._cursor: + cursor_op(self._cursor) + + utils.handle_exceptions_as_warning(lambda: self._conn.close(), failLog) + + def _safe_execute(self, f: CursorExecOp) -> CursorWrapper: + """ + Ensure that a previously opened cursor is closed and that a new one is created + before executing the given function. + Also ensures that we do not continue to execute SQL after a connection cleanup + has been requested. + """ + + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed connection") + assert self._conn, "Should not be possible for _conn to be None if open" + if self._cursor: + self._cursor.close() + self._cursor = CursorWrapper(f(self._conn.cursor())) + return self._cursor + + def __del__(self) -> None: + if self._cursor: + self._cursor.close() + + self.close() + + def __str__(self) -> str: + return f"Connection(session-id={self.session_id})" + + +class SqlUtils: + """ + Utility class for preparing cursor input/output. 
+ """ + + DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") + user_agent = f"dbt-databricks/{__version__}" + + @staticmethod + def extract_dbr_version(version: str) -> tuple[int, int]: + m = SqlUtils.DBR_VERSION_REGEX.search(version) + if m: + major = int(m.group(1)) + if m.group(2) == "x": + minor = sys.maxsize + else: + minor = int(m.group(2)) + return (major, minor) + else: + raise DbtRuntimeError("Failed to detect DBR version") + + @staticmethod + def translate_bindings(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[Any]]: + if bindings: + return list(map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x, bindings)) + return None + + @staticmethod + def clean_sql(sql: str) -> str: + cleaned = sql.strip() + if cleaned.endswith(";"): + cleaned = cleaned[:-1] + return cleaned + + @staticmethod + def prepare_connection_arguments( + creds: DatabricksCredentials, creds_provider: TCredentialProvider, http_path: str + ) -> dict[str, Any]: + invocation_env = creds.get_invocation_env() + user_agent_entry = SqlUtils.user_agent + if invocation_env: + user_agent_entry += f"; {invocation_env}" + + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + + http_headers: list[tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) + + return { + "server_hostname": creds.host, + "http_path": http_path, + "credentials_provider": creds_provider, + "http_headers": http_headers if http_headers else None, + "session_configuration": creds.session_properties, + "catalog": creds.database, + "use_inline_params": "silent", + "schema": creds.schema, + "_user_agent_entry": user_agent_entry, + **connection_parameters, + } diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 8adef638..c30015cd 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -625,9 +625,9 @@ def add_query( def run_sql_for_tests( self, sql: str, fetch: str, conn: Connection ) -> Optional[Union[Optional[tuple], list[tuple]]]: - cursor = conn.handle.cursor() + handle = conn.handle try: - cursor.execute(sql) + cursor = handle.execute(sql) if fetch == "one": return cursor.fetchone() elif fetch == "all": @@ -639,7 +639,8 @@ def run_sql_for_tests( print(e) raise finally: - cursor.close() + if cursor: + cursor.close() conn.transaction_open = False @available diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 3dfd4096..dccdd16c 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -6,6 +6,7 @@ from jinja2 import Undefined from dbt.adapters.base import BaseAdapter +from dbt.adapters.databricks.logging import logger from dbt.adapters.spark.impl import TABLE_OR_VIEW_NOT_FOUND_MESSAGES if TYPE_CHECKING: @@ -32,7 +33,7 @@ def _redact_credentials_in_copy_into(sql: str) -> str: f"{key.strip()} = '[REDACTED]'" for key, _ in (pair.strip().split("=", 1) for pair in m.group(1).split(",")) ) - return f"{sql[: m.start()]} ({redacted}){sql[m.end():]}" + return f"{sql[: m.start()]} ({redacted}){sql[m.end() :]}" else: return sql @@ -77,3 +78,13 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T: def quote(name: str) -> str: return f"`{name}`" + + +ExceptionToStrOp = Callable[[Exception], str] + + +def handle_exceptions_as_warning(op: Callable[[], None], log_gen: ExceptionToStrOp) -> None: + try: + op() + except Exception as e: + logger.warning(log_gen(e)) diff --git 
a/tests/unit/events/test_connection_events.py b/tests/unit/events/test_connection_events.py index 81a30c1d..6c005fb6 100644 --- a/tests/unit/events/test_connection_events.py +++ b/tests/unit/events/test_connection_events.py @@ -2,7 +2,6 @@ from dbt.adapters.databricks.events.connection_events import ( ConnectionAcquire, - ConnectionCloseError, ConnectionEvent, ) @@ -30,16 +29,6 @@ def test_connection_event__with_id(self): assert str(event) == "Connection(session-id=1234) - This is a test" -class TestConnectionCloseError: - def test_connection_close_error(self): - e = Exception("This is an exception") - event = ConnectionCloseError(None, e) - assert ( - str(event) == "Connection(session-id=Unknown) " - "- Exception while trying to close connection: This is an exception" - ) - - class TestConnectionAcquire: def test_connection_acquire__missing_data(self): event = ConnectionAcquire(None, None, None, (0, 0)) diff --git a/tests/unit/events/test_cursor_events.py b/tests/unit/events/test_cursor_events.py deleted file mode 100644 index e492415a..00000000 --- a/tests/unit/events/test_cursor_events.py +++ /dev/null @@ -1,41 +0,0 @@ -from unittest.mock import Mock - -from dbt.adapters.databricks.events.cursor_events import CursorCloseError, CursorEvent - - -class CursorTestEvent(CursorEvent): - def __init__(self, cursor): - super().__init__(cursor, "This is a test") - - -class TestCursorEvents: - def test_cursor_event__no_cursor(self): - event = CursorTestEvent(None) - assert str(event) == "Cursor(session-id=Unknown, command-id=Unknown) - This is a test" - - def test_cursor_event__no_ids(self): - mock = Mock() - mock.connection = None - mock.active_result_set = None - event = CursorTestEvent(mock) - assert str(event) == "Cursor(session-id=Unknown, command-id=Unknown) - This is a test" - - def test_cursor_event__with_ids(self): - mock = Mock() - mock.connection.get_session_id_hex.return_value = "1234" - mock.active_result_set.command_id.operationId.guid = (1234).to_bytes(16, "big") - event = CursorTestEvent(mock) - assert ( - str(event) == "Cursor(session-id=1234, command-id=00000000-0000-0000-0000-0000000004d2)" - " - This is a test" - ) - - -class TestCursorCloseError: - def test_cursor_close_error(self): - e = Exception("This is an exception") - event = CursorCloseError(None, e) - assert ( - str(event) == "Cursor(session-id=Unknown, command-id=Unknown) " - "- Exception while trying to close cursor: This is an exception" - ) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index d42fa5e1..f8c3b4a1 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -77,8 +77,7 @@ def test_two_catalog_settings(self): ) expected_message = ( - "Got duplicate keys: (`databricks.catalog` in session_properties)" - ' all map to "database"' + 'Got duplicate keys: (`databricks.catalog` in session_properties) all map to "database"' ) assert expected_message in str(excinfo.value) @@ -126,7 +125,7 @@ def test_custom_user_agent(self): adapter = DatabricksAdapter(config, get_context("spawn")) with patch( - "dbt.adapters.databricks.connections.dbsql.connect", + "dbt.adapters.databricks.handle.dbsql.connect", new=self._connect_func(expected_invocation_env="databricks-workflows"), ): with patch( @@ -189,7 +188,7 @@ def _test_environment_http_headers( adapter = DatabricksAdapter(config, get_context("spawn")) with patch( - "dbt.adapters.databricks.connections.dbsql.connect", + "dbt.adapters.databricks.handle.dbsql.connect", new=self._connect_func(expected_http_headers=expected_http_headers), 
): with patch( @@ -206,7 +205,7 @@ def test_oauth_settings(self): adapter = DatabricksAdapter(config, get_context("spawn")) with patch( - "dbt.adapters.databricks.connections.dbsql.connect", + "dbt.adapters.databricks.handle.dbsql.connect", new=self._connect_func(expected_no_token=True), ): connection = adapter.acquire_connection("dummy") @@ -219,7 +218,7 @@ def test_client_creds_settings(self): adapter = DatabricksAdapter(config, get_context("spawn")) with patch( - "dbt.adapters.databricks.connections.dbsql.connect", + "dbt.adapters.databricks.handle.dbsql.connect", new=self._connect_func(expected_client_creds=True), ): connection = adapter.acquire_connection("dummy") @@ -268,6 +267,7 @@ def connect( assert http_headers is None else: assert http_headers == expected_http_headers + return Mock() return connect @@ -278,7 +278,7 @@ def _test_databricks_sql_connector_connection(self, connect): config = self._get_config() adapter = DatabricksAdapter(config, get_context("spawn")) - with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect): + with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -302,7 +302,7 @@ def _test_databricks_sql_connector_catalog_connection(self, connect): config = self._get_config() adapter = DatabricksAdapter(config, get_context("spawn")) - with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect): + with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -329,7 +329,7 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co config = self._get_config(connection_parameters={"http_headers": http_headers}) adapter = DatabricksAdapter(config, get_context("spawn")) - with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect): + with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load diff --git a/tests/unit/test_handle.py b/tests/unit/test_handle.py new file mode 100644 index 00000000..01326063 --- /dev/null +++ b/tests/unit/test_handle.py @@ -0,0 +1,211 @@ +import sys +from decimal import Decimal +from unittest.mock import Mock + +import pytest +from databricks.sql.client import Cursor +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils + + +class TestSqlUtils: + @pytest.mark.parametrize( + "bindings, expected", [(None, None), ([1], [1]), ([1, Decimal(0.73)], [1, 0.73])] + ) + def test_translate_bindings(self, bindings, expected): + assert SqlUtils.translate_bindings(bindings) == expected + + @pytest.mark.parametrize( + "sql, expected", [(" select 1; ", "select 1"), ("select 1", "select 1")] + ) + def test_clean_sql(self, sql, expected): + assert SqlUtils.clean_sql(sql) == expected + + @pytest.mark.parametrize("result, expected", [("14.x", (14, sys.maxsize)), ("12.1", (12, 1))]) + def test_extract_dbr_version(self, result, expected): + assert SqlUtils.extract_dbr_version(result) == expected + + def test_extract_dbr_version__invalid(self): + with pytest.raises(DbtRuntimeError): + SqlUtils.extract_dbr_version("foo") + + +class TestCursorWrapper: + @pytest.fixture + def cursor(self): + return Mock() + + def test_description(self, 
cursor): + cursor.description = [("foo", "bar")] + wrapper = CursorWrapper(cursor) + assert wrapper.description == [("foo", "bar")] + + def test_cancel__closed(self, cursor): + wrapper = CursorWrapper(cursor) + wrapper.open = False + wrapper.cancel() + cursor.cancel.assert_not_called() + + def test_cancel__open_no_result_set(self, cursor): + wrapper = CursorWrapper(cursor) + cursor.active_result_set = None + wrapper.cancel() + assert wrapper.open is False + + def test_cancel__open_with_result_set(self, cursor): + wrapper = CursorWrapper(cursor) + wrapper.cancel() + cursor.cancel.assert_called_once() + + def test_cancel__error_cancelling(self, cursor): + cursor.cancel.side_effect = Exception("foo") + wrapper = CursorWrapper(cursor) + wrapper.cancel() + cursor.cancel.assert_called_once() + + def test_closed__closed(self, cursor): + wrapper = CursorWrapper(cursor) + wrapper.open = False + wrapper.close() + cursor.close.assert_not_called() + + def test_closed__open(self, cursor): + wrapper = CursorWrapper(cursor) + cursor.active_result_set = None + wrapper.close() + assert wrapper.open is False + + def test_close__error_closing(self, cursor): + cursor.close.side_effect = Exception("foo") + wrapper = CursorWrapper(cursor) + wrapper.close() + cursor.close.assert_called_once() + + def test_fetchone(self, cursor): + cursor.fetchone.return_value = [("foo", "bar")] + wrapper = CursorWrapper(cursor) + assert wrapper.fetchone() == [("foo", "bar")] + + def test_fetchall(self, cursor): + cursor.fetchall.return_value = [("foo", "bar")] + wrapper = CursorWrapper(cursor) + assert wrapper.fetchall() == [("foo", "bar")] + + def test_fetchmany(self, cursor): + cursor.fetchmany.return_value = [("foo", "bar")] + wrapper = CursorWrapper(cursor) + assert wrapper.fetchmany(1) == [("foo", "bar")] + + def test_get_response__no_query_id(self, cursor): + cursor.query_id = None + wrapper = CursorWrapper(cursor) + assert wrapper.get_response() == AdapterResponse("OK", query_id="N/A") + + def test_get_response__with_query_id(self, cursor): + cursor.query_id = "id" + wrapper = CursorWrapper(cursor) + assert wrapper.get_response() == AdapterResponse("OK", query_id="id") + + def test_with__no_exception(self, cursor): + with CursorWrapper(cursor) as c: + c.fetchone() + cursor.fetchone.assert_called_once() + cursor.close.assert_called_once() + + def test_with__exception(self, cursor): + cursor.fetchone.side_effect = Exception("foo") + with pytest.raises(Exception, match="foo"): + with CursorWrapper(cursor) as c: + c.fetchone() + cursor.fetchone.assert_called_once() + cursor.close.assert_called_once() + + +class TestDatabricksHandle: + @pytest.fixture + def conn(self): + return Mock() + + @pytest.fixture + def cursor(self): + return Mock() + + def test_safe_execute__closed(self, conn): + handle = DatabricksHandle(conn, True) + handle.open = False + with pytest.raises(DbtRuntimeError, match="Attempting to execute on a closed connection"): + handle._safe_execute(Mock()) + + def test_safe_execute__with_cursor(self, conn, cursor): + new_cursor = Mock() + + def f(_: Cursor) -> Cursor: + return new_cursor + + handle = DatabricksHandle(conn, True) + handle._cursor = cursor + assert handle._safe_execute(f)._cursor == new_cursor + assert handle._cursor._cursor == new_cursor + cursor.close.assert_called_once() + + def test_safe_execute__without_cursor(self, conn): + new_cursor = Mock() + + def f(_: Cursor) -> Cursor: + return new_cursor + + handle = DatabricksHandle(conn, True) + assert handle._safe_execute(f)._cursor == new_cursor + 
assert handle._cursor._cursor == new_cursor + + def test_cancel__closed(self, conn): + handle = DatabricksHandle(conn, True) + handle.open = False + handle.cancel() + conn.close.assert_not_called() + + def test_cancel__open_no_cursor(self, conn): + handle = DatabricksHandle(conn, True) + handle.cancel() + conn.close.assert_called_once() + + def test_cancel__open_cursor(self, conn, cursor): + handle = DatabricksHandle(conn, True) + handle._cursor = cursor + handle.cancel() + cursor.cancel.assert_called_once() + conn.close.assert_called_once() + + def test_cancel__open_raising_exception(self, conn): + conn.close.side_effect = Exception("foo") + handle = DatabricksHandle(conn, True) + handle.cancel() + conn.close.assert_called_once() + + def test_close__closed(self, conn): + handle = DatabricksHandle(conn, True) + handle.open = False + handle.close() + conn.close.assert_not_called() + + def test_close__open_no_cursor(self, conn): + handle = DatabricksHandle(conn, True) + handle.close() + conn.close.assert_called_once() + + def test_close__open_cursor(self, conn, cursor): + handle = DatabricksHandle(conn, True) + handle._cursor = cursor + handle.close() + cursor.close.assert_called_once() + conn.close.assert_called_once() + + def test_close__open_raising_exception(self, conn, cursor): + conn.close.side_effect = Exception("foo") + handle = DatabricksHandle(conn, True) + handle._cursor = cursor + handle.close() + cursor.close.assert_called_once() + conn.close.assert_called_once() diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py index de354568..7115554c 100644 --- a/tests/unit/test_idle_config.py +++ b/tests/unit/test_idle_config.py @@ -10,7 +10,7 @@ class TestDatabricksConnectionMaxIdleTime: """Test the various cases for determining a specified warehouse.""" errMsg = ( - "Compute resource foo does not exist or does not specify http_path, " "relation: a_relation" + "Compute resource foo does not exist or does not specify http_path, relation: a_relation" ) def test_get_max_idle_default(self): @@ -194,14 +194,14 @@ def test_get_max_idle_invalid(self): with pytest.raises(DbtRuntimeError) as info: connections._get_max_idle_time(node, creds) assert ( - "1.2.3 is not a valid value for connect_max_idle. " "Must be a number of seconds." + "1.2.3 is not a valid value for connect_max_idle. Must be a number of seconds." ) in str(info.value) creds.compute["alternate_compute"]["connect_max_idle"] = "1,002.3" with pytest.raises(DbtRuntimeError) as info: connections._get_max_idle_time(node, creds) assert ( - "1,002.3 is not a valid value for connect_max_idle. " "Must be a number of seconds." + "1,002.3 is not a valid value for connect_max_idle. Must be a number of seconds." 
) in str(info.value) def test_get_max_idle_simple_string_conversion(self): From d2d1d7afc9e63b409fdaeb376a99c6116381c8b0 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 6 Feb 2025 09:57:13 -0800 Subject: [PATCH 15/21] Unblocking iceberg snapshots, reactivating iceberg tests (#930) --- CHANGELOG.md | 4 ++++ dbt/adapters/databricks/impl.py | 5 +++-- .../adapter/iceberg/test_iceberg_support.py | 6 ++---- .../adapter/simple_snapshot/test_snapshot.py | 11 +++++++++++ 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ed8ce95..f92ead0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.9.5 (TBD) +### Fixes + +- table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930)) + ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index c30015cd..26f677e3 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -194,9 +194,10 @@ def update_tblproperties_for_iceberg( raise DbtConfigError( "When table_format is 'iceberg', cannot set file_format to other than delta." ) - if config.get("materialized") not in ("incremental", "table"): + if config.get("materialized") not in ("incremental", "table", "snapshot"): raise DbtConfigError( - "When table_format is 'iceberg', materialized must be 'incremental' or 'table'." + "When table_format is 'iceberg', materialized must be 'incremental'" + ", 'table', or 'snapshot'." ) result["delta.enableIcebergCompatV2"] = "true" result["delta.universalFormat.enabledFormats"] = "iceberg" diff --git a/tests/functional/adapter/iceberg/test_iceberg_support.py b/tests/functional/adapter/iceberg/test_iceberg_support.py index b25a54af..a52f4b97 100644 --- a/tests/functional/adapter/iceberg/test_iceberg_support.py +++ b/tests/functional/adapter/iceberg/test_iceberg_support.py @@ -5,8 +5,7 @@ from tests.functional.adapter.iceberg import fixtures -# @pytest.mark.skip_profile("databricks_cluster") -@pytest.mark.skip("Skip for now as it is broken in prod") +@pytest.mark.skip_profile("databricks_cluster") class TestIcebergTables: @pytest.fixture(scope="class") def models(self): @@ -21,8 +20,7 @@ def test_iceberg_refs(self, project): assert len(run_results) == 3 -# @pytest.mark.skip_profile("databricks_cluster") -@pytest.mark.skip("Skip for now as it is broken in prod") +@pytest.mark.skip_profile("databricks_cluster") class TestIcebergSwap: @pytest.fixture(scope="class") def models(self): diff --git a/tests/functional/adapter/simple_snapshot/test_snapshot.py b/tests/functional/adapter/simple_snapshot/test_snapshot.py index 56186e4e..a77e9bef 100644 --- a/tests/functional/adapter/simple_snapshot/test_snapshot.py +++ b/tests/functional/adapter/simple_snapshot/test_snapshot.py @@ -32,6 +32,17 @@ class TestSnapshotCheck(BaseSnapshotCheck): pass +@pytest.mark.skip_profile("databricks_cluster") +class TestSnapshotIceberg(BaseSnapshotCheck): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "+table_format": "iceberg", + } + } + + class TestSnapshotPersistDocs: @pytest.fixture(scope="class") def models(self): From 4b1d2d99b027ec3d88e8ef6ab5e89ba7a83fb15b Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 6 Feb 2025 
12:50:34 -0800 Subject: [PATCH 16/21] Fix #933 (#934) --- CHANGELOG.md | 1 + dbt/adapters/databricks/impl.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f92ead0f..74ec15a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Fixes - table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930)) +- Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934)) ### Under the Hood diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 26f677e3..e9b4fce4 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -140,7 +140,7 @@ def get_identifier_list_string(table_names: set[str]) -> str: _identifier = "|".join(table_names) bypass_2048_char_limit = GlobalState.get_char_limit_bypass() - if bypass_2048_char_limit == "true": + if bypass_2048_char_limit: _identifier = _identifier if len(_identifier) < 2048 else "*" return _identifier From f6a58ee4227566f5b2b1498de8d7e4e2518c98ef Mon Sep 17 00:00:00 2001 From: ShaneMazur <30323745+ShaneMazur@users.noreply.github.com> Date: Mon, 10 Feb 2025 13:05:44 -0500 Subject: [PATCH 17/21] Allow auto liquid clustering (#935) --- CHANGELOG.md | 5 +++++ dbt/adapters/databricks/impl.py | 1 + .../macros/relations/liquid_clustering.sql | 8 ++++++++ .../databricks/macros/relations/optimize.sql | 6 +++--- .../adapter/liquid_clustering/fixtures.py | 5 +++++ .../liquid_clustering/test_liquid_clustering.py | 13 +++++++++++++ tests/unit/macros/relations/test_table_macros.py | 10 ++++++++++ 7 files changed, 45 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74ec15a6..c1c2d60b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,15 @@ ## dbt-databricks 1.9.5 (TBD) +### Features + +- Add `auto_liquid_cluster` config to enable Auto Liquid Clustering for Delta-based dbt models ([935](https://github.com/databricks/dbt-databricks/pull/935)) + ### Fixes - table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930)) - Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934)) + ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index e9b4fce4..24dd4d27 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -110,6 +110,7 @@ class DatabricksConfig(AdapterConfig): partition_by: Optional[Union[list[str], str]] = None clustered_by: Optional[Union[list[str], str]] = None liquid_clustered_by: Optional[Union[list[str], str]] = None + auto_liquid_cluster: Optional[bool] = None buckets: Optional[int] = None options: Optional[dict[str, str]] = None merge_update_columns: Optional[str] = None diff --git a/dbt/include/databricks/macros/relations/liquid_clustering.sql b/dbt/include/databricks/macros/relations/liquid_clustering.sql index b30269fd..43a3b113 100644 --- a/dbt/include/databricks/macros/relations/liquid_clustering.sql +++ b/dbt/include/databricks/macros/relations/liquid_clustering.sql @@ -1,15 +1,19 @@ {% macro liquid_clustered_cols() -%} {%- set cols = config.get('liquid_clustered_by', validator=validation.any[list, basestring]) -%} + {%- set auto_cluster = config.get('auto_liquid_cluster', 
validator=validation.any[boolean]) -%} {%- if cols is not none %} {%- if cols is string -%} {%- set cols = [cols] -%} {%- endif -%} CLUSTER BY ({{ cols | join(', ') }}) + {%- elif auto_cluster -%} + CLUSTER BY AUTO {%- endif %} {%- endmacro -%} {% macro apply_liquid_clustered_cols(target_relation) -%} {%- set cols = config.get('liquid_clustered_by', validator=validation.any[list, basestring]) -%} + {%- set auto_cluster = config.get('auto_liquid_cluster', validator=validation.any[boolean]) -%} {%- if cols is not none %} {%- if cols is string -%} {%- set cols = [cols] -%} @@ -17,5 +21,9 @@ {%- call statement('set_cluster_by_columns') -%} ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY ({{ cols | join(', ') }}) {%- endcall -%} + {%- elif auto_cluster -%} + {%- call statement('set_cluster_by_auto') -%} + ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY AUTO + {%- endcall -%} {%- endif %} {%- endmacro -%} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/optimize.sql b/dbt/include/databricks/macros/relations/optimize.sql index 79f00164..a6108709 100644 --- a/dbt/include/databricks/macros/relations/optimize.sql +++ b/dbt/include/databricks/macros/relations/optimize.sql @@ -6,7 +6,7 @@ {%- if var('DATABRICKS_SKIP_OPTIMIZE', 'false')|lower != 'true' and var('databricks_skip_optimize', 'false')|lower != 'true' and config.get('file_format', 'delta') == 'delta' -%} - {%- if (config.get('zorder', False) or config.get('liquid_clustered_by', False)) -%} + {%- if (config.get('zorder', False) or config.get('liquid_clustered_by', False)) or config.get('auto_liquid_cluster', False) -%} {%- call statement('run_optimize_stmt') -%} {{ get_optimize_sql(relation) }} {%- endcall -%} @@ -17,8 +17,8 @@ {%- macro get_optimize_sql(relation) %} optimize {{ relation }} {%- if config.get('zorder', False) and config.get('file_format', 'delta') == 'delta' %} - {%- if config.get('liquid_clustered_by', False) %} - {{ exceptions.warn("Both zorder and liquid_clustered_by are set but they are incompatible. zorder will be ignored.") }} + {%- if config.get('liquid_clustered_by', False) or config.get('auto_liquid_cluster', False) %} + {{ exceptions.warn("Both zorder and liquid_clustering are set but they are incompatible. zorder will be ignored.") }} {%- else %} {%- set zorder = config.get('zorder', none) %} {# TODO: predicates here? WHERE ... 
#} diff --git a/tests/functional/adapter/liquid_clustering/fixtures.py b/tests/functional/adapter/liquid_clustering/fixtures.py index 951210ea..6c730222 100644 --- a/tests/functional/adapter/liquid_clustering/fixtures.py +++ b/tests/functional/adapter/liquid_clustering/fixtures.py @@ -2,3 +2,8 @@ {{ config(materialized='incremental', liquid_clustered_by='id') }} select 1 as id, 'Joe' as name """ + +auto_liquid_cluster_sql = """ +{{ config(materialized='incremental', auto_liquid_cluster=true) }} +select 1 as id, 'Joe' as name +""" diff --git a/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py b/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py index a9cc0ee0..45c7dfe2 100644 --- a/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py +++ b/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py @@ -15,3 +15,16 @@ def models(self): def test_liquid_clustering(self, project): _, logs = util.run_dbt_and_capture(["--debug", "run"]) assert "optimize" in logs + + +class TestAutoLiquidClustering: + @pytest.fixture(scope="class") + def models(self): + return { + "liquid_clustering.sql": fixtures.liquid_cluster_sql, + } + + @pytest.mark.skip_profile("databricks_uc_cluster", "databricks_cluster") + def test_liquid_clustering(self, project): + _, logs = util.run_dbt_and_capture(["--debug", "run"]) + assert "optimize" in logs diff --git a/tests/unit/macros/relations/test_table_macros.py b/tests/unit/macros/relations/test_table_macros.py index 9c563513..f05dc9d4 100644 --- a/tests/unit/macros/relations/test_table_macros.py +++ b/tests/unit/macros/relations/test_table_macros.py @@ -180,6 +180,16 @@ def test_macros_create_table_as_liquid_clusters(self, config, template_bundle): assert sql == expected + def test_macros_create_table_as_liquid_cluster_auto(self, config, template_bundle): + config["auto_liquid_cluster"] = True + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table {template_bundle.relation} using" + " delta CLUSTER BY AUTO as select 1" + ) + + assert sql == expected + def test_macros_create_table_as_comment(self, config, template_bundle): config["persist_docs"] = {"relation": True} template_bundle.context["model"].description = "Description Test" From fbb3b7e1ae4b4cc9e3ca3cbacb12c4fd7b968b0a Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 11 Feb 2025 12:02:34 -0800 Subject: [PATCH 18/21] Prepare for environments with serverless notebooks (#938) --- CHANGELOG.md | 2 +- dbt/adapters/databricks/api_client.py | 4 +++ .../databricks/python_models/python_config.py | 2 ++ .../python_models/python_submissions.py | 15 ++++++++++- .../adapter/python_model/fixtures.py | 26 +++++++++++++++++++ .../adapter/python_model/test_python_model.py | 15 +++++++++++ tests/unit/python/test_python_job_support.py | 23 ++++++++++++++-- 7 files changed, 83 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1c2d60b..40570724 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,13 +3,13 @@ ### Features - Add `auto_liquid_cluster` config to enable Auto Liquid Clustering for Delta-based dbt models ([935](https://github.com/databricks/dbt-databricks/pull/935)) +- Prepare for environments for python models with serverless clusters ([938](https://github.com/databricks/dbt-databricks/pull/938)) ### Fixes - table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930)) - Fix for regression in glue table 
listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934)) - ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py index 57bc4ca3..251da5a6 100644 --- a/dbt/adapters/databricks/api_client.py +++ b/dbt/adapters/databricks/api_client.py @@ -326,6 +326,10 @@ def __init__(self, session: Session, host: str, polling_interval: int, timeout: def submit( self, run_name: str, job_spec: dict[str, Any], **additional_job_settings: dict[str, Any] ) -> str: + logger.debug( + f"Submitting job with run_name={run_name} and job_spec={job_spec}" + " and additional_job_settings={additional_job_settings}" + ) submit_response = self.session.post( "/submit", json={"run_name": run_name, "tasks": [job_spec], **additional_job_settings} ) diff --git a/dbt/adapters/databricks/python_models/python_config.py b/dbt/adapters/databricks/python_models/python_config.py index 29aa44ef..5a39a5c9 100644 --- a/dbt/adapters/databricks/python_models/python_config.py +++ b/dbt/adapters/databricks/python_models/python_config.py @@ -36,6 +36,8 @@ class PythonModelConfig(BaseModel): cluster_id: Optional[str] = None http_path: Optional[str] = None create_notebook: bool = False + environment_key: Optional[str] = None + environment_dependencies: list[str] = Field(default_factory=list) class ParsedPythonModel(BaseModel): diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index afcb383c..f98db234 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -209,6 +209,8 @@ def __init__( self.job_grants = parsed_model.config.python_job_config.grants self.acls = parsed_model.config.access_control_list self.additional_job_settings = parsed_model.config.python_job_config.dict() + self.environment_key = parsed_model.config.environment_key + self.environment_deps = parsed_model.config.environment_dependencies def compile(self, path: str) -> PythonJobDetails: job_spec: dict[str, Any] = { @@ -217,9 +219,20 @@ def compile(self, path: str) -> PythonJobDetails: "notebook_path": path, }, } - job_spec.update(self.cluster_spec) # updates 'new_cluster' config additional_job_config = self.additional_job_settings + + if self.environment_key: + job_spec["environment_key"] = self.environment_key + if self.environment_deps and not self.additional_job_settings.get("environments"): + additional_job_config["environments"] = [ + { + "environment_key": self.environment_key, + "spec": {"client": "2", "dependencies": self.environment_deps}, + } + ] + job_spec.update(self.cluster_spec) # updates 'new_cluster' config + access_control_list = self.permission_builder.build_job_permissions( self.job_grants, self.acls ) diff --git a/tests/functional/adapter/python_model/fixtures.py b/tests/functional/adapter/python_model/fixtures.py index 127bcf74..d1a4dd4a 100644 --- a/tests/functional/adapter/python_model/fixtures.py +++ b/tests/functional/adapter/python_model/fixtures.py @@ -42,6 +42,32 @@ def model(dbt, spark): identifier: source """ +serverless_schema_with_environment = """version: 2 + +models: + - name: my_versioned_sql_model + versions: + - v: 1 + - name: my_python_model + config: + submission_method: serverless_cluster + create_notebook: true + environment_key: "test_key" + 
environment_dependencies: ["requests"] + +sources: + - name: test_source + loader: custom + schema: "{{ var(env_var('DBT_TEST_SCHEMA_NAME_VARIABLE')) }}" + quoting: + identifier: True + tags: + - my_test_source_tag + tables: + - name: test_table + identifier: source +""" + workflow_schema = """version: 2 models: diff --git a/tests/functional/adapter/python_model/test_python_model.py b/tests/functional/adapter/python_model/test_python_model.py index 726791df..ce490c0f 100644 --- a/tests/functional/adapter/python_model/test_python_model.py +++ b/tests/functional/adapter/python_model/test_python_model.py @@ -113,6 +113,21 @@ def models(self): } +@pytest.mark.python +# @pytest.mark.skip_profile("databricks_cluster", "databricks_uc_cluster") +@pytest.mark.skip("Not available in Databricks yet") +class TestServerlessClusterWithEnvironment(BasePythonModelTests): + @pytest.fixture(scope="class") + def models(self): + return { + "schema.yml": override_fixtures.serverless_schema_with_environment, + "my_sql_model.sql": fixtures.basic_sql, + "my_versioned_sql_model_v1.sql": fixtures.basic_sql, + "my_python_model.py": fixtures.basic_python, + "second_sql_model.sql": fixtures.second_sql, + } + + @pytest.mark.python @pytest.mark.external @pytest.mark.skip_profile("databricks_cluster", "databricks_uc_sql_endpoint") diff --git a/tests/unit/python/test_python_job_support.py b/tests/unit/python/test_python_job_support.py index 41f48041..fb996efa 100644 --- a/tests/unit/python/test_python_job_support.py +++ b/tests/unit/python/test_python_job_support.py @@ -146,8 +146,16 @@ def run_name(self, parsed_model): parsed_model.config.additional_libs = [] return run_name + @pytest.fixture + def environment_key(self, parsed_model): + environment_key = "test_key" + parsed_model.config.environment_key = environment_key + parsed_model.config.environment_dependencies = ["requests"] + return environment_key + def test_compile__empty_configs(self, client, permission_builder, parsed_model, run_name): parsed_model.config.python_job_config.dict.return_value = {} + parsed_model.config.environment_key = None compiler = PythonJobConfigCompiler(client, permission_builder, parsed_model, {}) permission_builder.build_job_permissions.return_value = [] details = compiler.compile("path") @@ -162,7 +170,9 @@ def test_compile__empty_configs(self, client, permission_builder, parsed_model, } assert details.additional_job_config == {} - def test_compile__nonempty_configs(self, client, permission_builder, parsed_model, run_name): + def test_compile__nonempty_configs( + self, client, permission_builder, parsed_model, run_name, environment_key + ): parsed_model.config.packages = ["foo"] parsed_model.config.index_url = None parsed_model.config.python_job_config.dict.return_value = {"foo": "bar"} @@ -176,6 +186,7 @@ def test_compile__nonempty_configs(self, client, permission_builder, parsed_mode details = compiler.compile("path") assert details.run_name == run_name assert details.job_spec == { + "environment_key": environment_key, "task_key": "inner_notebook", "notebook_task": { "notebook_path": "path", @@ -185,4 +196,12 @@ def test_compile__nonempty_configs(self, client, permission_builder, parsed_mode "access_control_list": [{"user_name": "user", "permission_level": "IS_OWNER"}], "queue": {"enabled": True}, } - assert details.additional_job_config == {"foo": "bar"} + assert details.additional_job_config == { + "foo": "bar", + "environments": [ + { + "environment_key": environment_key, + "spec": {"client": "2", "dependencies": ["requests"]}, + } 
+ ], + } From 40c23374210f814334bafe59ec03e5bf18c5d86b Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 11 Feb 2025 14:23:38 -0800 Subject: [PATCH 19/21] Prep for 1.9.5 release (#937) --- CHANGELOG.md | 11 ++++++----- dbt/adapters/databricks/__version__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40570724..fc09f063 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,23 @@ -## dbt-databricks 1.9.5 (TBD) +## dbt-databricks 1.9.5 (Feb 11, 2025) ### Features -- Add `auto_liquid_cluster` config to enable Auto Liquid Clustering for Delta-based dbt models ([935](https://github.com/databricks/dbt-databricks/pull/935)) +- Add `auto_liquid_cluster` config to enable Auto Liquid Clustering for Delta-based dbt models (thanks @ShaneMazur!) ([935](https://github.com/databricks/dbt-databricks/pull/935)) - Prepare for environments for python models with serverless clusters ([938](https://github.com/databricks/dbt-databricks/pull/938)) ### Fixes - table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930)) - Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934)) +- Use POSIX standard when creating location for the tables (thanks @gsolasab!) ([919](https://github.com/databricks/dbt-databricks/pull/919)) + ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) -- Use POSIX standard when creating location for the tables ([919](https://github.com/databricks/dbt-databricks/pull/919)) - Clean up cursor management in the hopes of limiting issues with cancellation ([912](https://github.com/databricks/dbt-databricks/pull/912)) -## dbt-databricks 1.9.4 (Jan 30, 2024) +## dbt-databricks 1.9.4 (Jan 30, 2025) ### Under the Hood @@ -26,7 +27,7 @@ Yanked due to being published with the incorrect bits -## dbt-databricks 1.9.2 (Jan 21, 2024) +## dbt-databricks 1.9.2 (Jan 21, 2025) ### Features diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py index 53988968..14e6fa9f 100644 --- a/dbt/adapters/databricks/__version__.py +++ b/dbt/adapters/databricks/__version__.py @@ -1 +1 @@ -version = "1.9.4" +version = "1.9.5" diff --git a/pyproject.toml b/pyproject.toml index 7c8c0e82..c176f1ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "databricks-sdk==0.17.0", - "databricks-sql-connector>=3.5.0, <3.7.0", + "databricks-sql-connector>=3.5.0, <4.0.0", "dbt-adapters>=1.7.0, <2.0", "dbt-common>=1.10.0, <2.0", "dbt-core>=1.8.7, <2.0", From 2ccd72930d32a887bba443bd6d0f895c5a01f94b Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 13 Feb 2025 16:03:26 -0800 Subject: [PATCH 20/21] Scope connection parameter creation (#929) --- CHANGELOG.md | 9 +- dbt/adapters/databricks/connections.py | 142 +++++++-------- tests/unit/test_compute_config.py | 59 ------ tests/unit/test_idle_config.py | 240 ------------------------- tests/unit/test_query_config.py | 93 ++++++++++ 5 files changed, 172 insertions(+), 371 deletions(-) delete mode 100644 tests/unit/test_compute_config.py delete mode 100644 tests/unit/test_idle_config.py create mode 100644 tests/unit/test_query_config.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fc09f063..d6cefd9c 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,10 @@ -## dbt-databricks 1.9.5 (Feb 11, 2025) +## dbt-databricks 1.9.6 (TBD) + +### Under the Hood + +- Refactoring of some connection internals ([929](https://github.com/databricks/dbt-databricks/pull/929)) + +## dbt-databricks 1.9.5 (Feb 13, 2025) ### Features @@ -11,7 +17,6 @@ - Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934)) - Use POSIX standard when creating location for the tables (thanks @gsolasab!) ([919](https://github.com/databricks/dbt-databricks/pull/919)) - ### Under the Hood - Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index b536fda6..56bf196c 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -94,6 +94,26 @@ """ +@dataclass(frozen=True) +class QueryContextWrapper: + """ + Until dbt tightens this protocol up, we need to wrap the context for safety + """ + + compute_name: Optional[str] = None + relation_name: Optional[str] = None + + @staticmethod + def from_context(query_header_context: Any) -> "QueryContextWrapper": + if query_header_context is None: + return QueryContextWrapper() + compute_name = None + relation_name = getattr(query_header_context, "relation_name", "[unknown]") + if hasattr(query_header_context, "config") and query_header_context.config: + compute_name = query_header_context.config.get("databricks_compute") + return QueryContextWrapper(compute_name=compute_name, relation_name=relation_name) + + class DatabricksMacroQueryStringSetter(MacroQueryStringSetter): def _get_comment_macro(self) -> Optional[str]: if self.config.query_comment.comment == DEFAULT_QUERY_COMMENT: @@ -238,16 +258,16 @@ def set_connection_name( self._cleanup_idle_connections() conn_name: str = "master" if name is None else name - + wrapped = QueryContextWrapper.from_context(query_header_context) # Get a connection for this thread - conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") + conn = self._get_if_exists_compute_connection(wrapped.compute_name or "") if conn is None: - conn = self._create_compute_connection(conn_name, query_header_context) + conn = self._create_compute_connection(conn_name, wrapped) else: # existing connection either wasn't open or didn't have the right name conn = self._update_compute_connection(conn, conn_name) - conn._acquire(query_header_context) + conn._acquire(wrapped) return conn @@ -518,13 +538,13 @@ def _cleanup_idle_connections(self) -> None: conn._reset_handle(self.open) def _create_compute_connection( - self, conn_name: str, query_header_context: Any = None + self, conn_name: str, query_header_context: QueryContextWrapper ) -> DatabricksDBTConnection: """Create anew connection for the combination of current thread and compute associated with the given node.""" # Create a new connection - compute_name = _get_compute_name(query_header_context) or "" + compute_name = query_header_context.compute_name or "" conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -536,9 +556,9 @@ def _create_compute_connection( ) conn.compute_name = compute_name creds = cast(DatabricksCredentials, self.profile.credentials) - conn.http_path = _get_http_path(query_header_context, creds=creds) or "" + conn.http_path = QueryConfigUtils.get_http_path(query_header_context, creds) 
conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) - conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds) + conn.max_idle_time = QueryConfigUtils.get_max_idle_time(query_header_context, creds) conn.handle = LazyHandle(self.open) @@ -604,74 +624,56 @@ def _update_compute_connection( return conn -def _get_compute_name(query_header_context: Any) -> Optional[str]: - # Get the name of the specified compute resource from the node's - # config. - compute_name = None - if ( - query_header_context - and hasattr(query_header_context, "config") - and query_header_context.config - ): - compute_name = query_header_context.config.get("databricks_compute", None) - return compute_name - - -def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> Optional[str]: - """Get the http_path for the compute specified for the node. - If none is specified default will be used.""" - - # ResultNode *should* have relation_name attr, but we work around a core - # issue by checking. - relation_name = getattr(query_header_context, "relation_name", "[unknown]") - - # If there is no node we return the http_path for the default compute. - if not query_header_context: - return creds.http_path - - # Get the name of the compute resource specified in the node's config. - # If none is specified return the http_path for the default compute. - compute_name = _get_compute_name(query_header_context) - if not compute_name: - return creds.http_path - - # Get the http_path for the named compute. - http_path = None - if creds.compute: - http_path = creds.compute.get(compute_name, {}).get("http_path", None) - - # no http_path for the named compute resource is an error condition - if not http_path: - raise DbtRuntimeError( - f"Compute resource {compute_name} does not exist or " - f"does not specify http_path, relation: {relation_name}" - ) +class QueryConfigUtils: + """ + Utility class for getting config values from QueryHeaderContextWrapper and Credentials. + """ - return http_path + @staticmethod + def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> str: + """ + Get the http_path for the compute specified for the node. + If none is specified default will be used. + """ + if not context.compute_name: + return creds.http_path or "" + + # Get the http_path for the named compute. + http_path = None + if creds.compute: + http_path = creds.compute.get(context.compute_name, {}).get("http_path", None) + + # no http_path for the named compute resource is an error condition + if not http_path: + raise DbtRuntimeError( + f"Compute resource {context.compute_name} does not exist or " + f"does not specify http_path, relation: {context.relation_name}" + ) -def _get_max_idle_time(query_header_context: Any, creds: DatabricksCredentials) -> int: - """Get the http_path for the compute specified for the node. - If none is specified default will be used.""" + return http_path - max_idle_time = ( - DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle - ) + @staticmethod + def get_max_idle_time(context: QueryContextWrapper, creds: DatabricksCredentials) -> int: + """Get the http_path for the compute specified for the node. 
+ If none is specified default will be used.""" - if query_header_context: - compute_name = _get_compute_name(query_header_context) - if compute_name and creds.compute: - max_idle_time = creds.compute.get(compute_name, {}).get( + max_idle_time = ( + DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle + ) + + if context.compute_name and creds.compute: + max_idle_time = creds.compute.get(context.compute_name, {}).get( "connect_max_idle", max_idle_time ) - if not isinstance(max_idle_time, int): - if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): - return int(max_idle_time.strip()) - else: - raise DbtRuntimeError( - f"{max_idle_time} is not a valid value for connect_max_idle. " - "Must be a number of seconds." - ) + if not isinstance(max_idle_time, int): + if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): + return int(max_idle_time.strip()) + else: + raise DbtRuntimeError( + f"{max_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds." + ) - return max_idle_time + return max_idle_time diff --git a/tests/unit/test_compute_config.py b/tests/unit/test_compute_config.py deleted file mode 100644 index 994d4ae9..00000000 --- a/tests/unit/test_compute_config.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -import pytest -from dbt_common.exceptions import DbtRuntimeError - -from dbt.adapters.databricks import connections -from dbt.adapters.databricks.credentials import DatabricksCredentials - - -class TestDatabricksConnectionHTTPPath: - """Test the various cases for determining a specified warehouse.""" - - @pytest.fixture(scope="class") - def err_msg(self): - return ( - "Compute resource foo does not exist or does not specify http_path," - " relation: a_relation" - ) - - @pytest.fixture(scope="class") - def path(self): - return "my_http_path" - - @pytest.fixture - def creds(self, path): - return DatabricksCredentials(http_path=path) - - @pytest.fixture - def node(self): - n = Mock() - n.config = {} - n.relation_name = "a_relation" - return n - - def test_get_http_path__empty(self, path, creds): - assert connections._get_http_path(None, creds) == path - - def test_get_http_path__no_compute(self, node, path, creds): - assert connections._get_http_path(node, creds) == path - - def test_get_http_path__missing_compute(self, node, creds, err_msg): - node.config["databricks_compute"] = "foo" - with pytest.raises(DbtRuntimeError) as exc: - connections._get_http_path(node, creds) - - assert err_msg in str(exc.value) - - def test_get_http_path__empty_compute(self, node, creds, err_msg): - node.config["databricks_compute"] = "foo" - creds.compute = {"foo": {}} - with pytest.raises(DbtRuntimeError) as exc: - connections._get_http_path(node, creds) - - assert err_msg in str(exc.value) - - def test_get_http_path__matching_compute(self, node, creds): - node.config["databricks_compute"] = "foo" - creds.compute = {"foo": {"http_path": "alternate_path"}} - assert "alternate_path" == connections._get_http_path(node, creds) diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py deleted file mode 100644 index 7115554c..00000000 --- a/tests/unit/test_idle_config.py +++ /dev/null @@ -1,240 +0,0 @@ -import pytest -from dbt_common.exceptions import DbtRuntimeError - -from dbt.adapters.databricks import connections -from dbt.adapters.databricks.credentials import DatabricksCredentials -from dbt.contracts.graph import model_config, nodes - - -class 
TestDatabricksConnectionMaxIdleTime: - """Test the various cases for determining a specified warehouse.""" - - errMsg = ( - "Compute resource foo does not exist or does not specify http_path, relation: a_relation" - ) - - def test_get_max_idle_default(self): - creds = DatabricksCredentials() - - # No node and nothing specified in creds - time = connections._get_max_idle_time(None, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - node = nodes.ModelNode( - relation_name="a_relation", - database="database", - schema="schema", - name="node_name", - resource_type="model", - package_name="package", - path="path", - original_file_path="orig_path", - unique_id="uniqueID", - fqn=[], - alias="alias", - checksum=None, - ) - - # node has no configuration so should get back default - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - # empty configuration should return default - node.config = model_config.ModelConfig() - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - # node with no extras in configuration should return default - node.config._extra = {} - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - # node that specifies a compute with no corresponding definition should return default - node.config._extra["databricks_compute"] = "foo" - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - creds.compute = {} - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - - # if alternate compute doesn't specify a max time should return default - creds.compute = {"foo": {}} - time = connections._get_max_idle_time(node, creds) - assert connections.DEFAULT_MAX_IDLE_TIME == time - # with self.assertRaisesRegex( - # dbt.exceptions.DbtRuntimeError, - # self.errMsg, - # ): - # connections._get_http_path(node, creds) - - # creds.compute = {"foo": {"http_path": "alternate_path"}} - # path = connections._get_http_path(node, creds) - # self.assertEqual("alternate_path", path) - - def test_get_max_idle_creds(self): - creds_idle_time = 77 - creds = DatabricksCredentials(connect_max_idle=creds_idle_time) - - # No node so value should come from creds - time = connections._get_max_idle_time(None, creds) - assert creds_idle_time == time - - node = nodes.ModelNode( - relation_name="a_relation", - database="database", - schema="schema", - name="node_name", - resource_type="model", - package_name="package", - path="path", - original_file_path="orig_path", - unique_id="uniqueID", - fqn=[], - alias="alias", - checksum=None, - ) - - # node has no configuration so should get value from creds - time = connections._get_max_idle_time(node, creds) - assert creds_idle_time == time - - # empty configuration should get value from creds - node.config = model_config.ModelConfig() - time = connections._get_max_idle_time(node, creds) - assert creds_idle_time == time - - # node with no extras in configuration should get value from creds - node.config._extra = {} - time = connections._get_max_idle_time(node, creds) - assert creds_idle_time == time - - # node that specifies a compute with no corresponding definition should get value from creds - node.config._extra["databricks_compute"] = "foo" - time = connections._get_max_idle_time(node, creds) - assert creds_idle_time == time - - creds.compute = {} - time = connections._get_max_idle_time(node, creds) - assert 
creds_idle_time == time - - # if alternate compute doesn't specify a max time should get value from creds - creds.compute = {"foo": {}} - time = connections._get_max_idle_time(node, creds) - assert creds_idle_time == time - - def test_get_max_idle_compute(self): - creds_idle_time = 88 - compute_idle_time = 77 - creds = DatabricksCredentials(connect_max_idle=creds_idle_time) - creds.compute = {"foo": {"connect_max_idle": compute_idle_time}} - - node = nodes.SnapshotNode( - config=None, - relation_name="a_relation", - database="database", - schema="schema", - name="node_name", - resource_type="model", - package_name="package", - path="path", - original_file_path="orig_path", - unique_id="uniqueID", - fqn=[], - alias="alias", - checksum=None, - ) - - node.config = model_config.SnapshotConfig() - node.config._extra = {"databricks_compute": "foo"} - - time = connections._get_max_idle_time(node, creds) - assert compute_idle_time == time - - def test_get_max_idle_invalid(self): - creds_idle_time = "foo" - compute_idle_time = "bar" - creds = DatabricksCredentials(connect_max_idle=creds_idle_time) - creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} - - node = nodes.SnapshotNode( - config=None, - relation_name="a_relation", - database="database", - schema="schema", - name="node_name", - resource_type="model", - package_name="package", - path="path", - original_file_path="orig_path", - unique_id="uniqueID", - fqn=[], - alias="alias", - checksum=None, - ) - - node.config = model_config.SnapshotConfig() - - with pytest.raises(DbtRuntimeError) as info: - connections._get_max_idle_time(node, creds) - assert ( - f"{creds_idle_time} is not a valid value for connect_max_idle. " - "Must be a number of seconds." - ) in str(info.value) - - node.config._extra["databricks_compute"] = "alternate_compute" - with pytest.raises(DbtRuntimeError) as info: - connections._get_max_idle_time(node, creds) - assert ( - f"{compute_idle_time} is not a valid value for connect_max_idle. " - "Must be a number of seconds." - ) in str(info.value) - - creds.compute["alternate_compute"]["connect_max_idle"] = "1.2.3" - with pytest.raises(DbtRuntimeError) as info: - connections._get_max_idle_time(node, creds) - assert ( - "1.2.3 is not a valid value for connect_max_idle. Must be a number of seconds." - ) in str(info.value) - - creds.compute["alternate_compute"]["connect_max_idle"] = "1,002.3" - with pytest.raises(DbtRuntimeError) as info: - connections._get_max_idle_time(node, creds) - assert ( - "1,002.3 is not a valid value for connect_max_idle. Must be a number of seconds." 
- ) in str(info.value) - - def test_get_max_idle_simple_string_conversion(self): - creds_idle_time = "12" - compute_idle_time = "34" - creds = DatabricksCredentials(connect_max_idle=creds_idle_time) - creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} - - node = nodes.SnapshotNode( - config=None, - relation_name="a_relation", - database="database", - schema="schema", - name="node_name", - resource_type="model", - package_name="package", - path="path", - original_file_path="orig_path", - unique_id="uniqueID", - fqn=[], - alias="alias", - checksum=None, - ) - - node.config = model_config.SnapshotConfig() - - time = connections._get_max_idle_time(node, creds) - assert float(creds_idle_time) == time - - node.config._extra["databricks_compute"] = "alternate_compute" - time = connections._get_max_idle_time(node, creds) - assert float(compute_idle_time) == time - - creds.compute["alternate_compute"]["connect_max_idle"] = " 56 " - time = connections._get_max_idle_time(node, creds) - assert 56 == time diff --git a/tests/unit/test_query_config.py b/tests/unit/test_query_config.py new file mode 100644 index 00000000..61813200 --- /dev/null +++ b/tests/unit/test_query_config.py @@ -0,0 +1,93 @@ +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks import connections +from dbt.adapters.databricks.connections import QueryConfigUtils, QueryContextWrapper +from dbt.adapters.databricks.credentials import DatabricksCredentials + + +class TestQueryConfigUtils: + """Test the various cases for determining a specified warehouse.""" + + @pytest.fixture(scope="class") + def err_msg(self): + return ( + "Compute resource foo does not exist or does not specify http_path," + " relation: a_relation" + ) + + @pytest.fixture(scope="class") + def path(self): + return "my_http_path" + + @pytest.fixture + def creds(self, path): + return DatabricksCredentials(http_path=path) + + def test_get_http_path__empty(self, path, creds): + assert QueryConfigUtils.get_http_path(QueryContextWrapper(), creds) == path + + def test_get_http_path__no_compute(self, path, creds): + assert ( + QueryConfigUtils.get_http_path(QueryContextWrapper(relation_name="a_relation"), creds) + == path + ) + + def test_get_http_path__missing_compute(self, creds, err_msg): + context = QueryContextWrapper(compute_name="foo", relation_name="a_relation") + with pytest.raises(DbtRuntimeError) as exc: + QueryConfigUtils.get_http_path(context, creds) + + assert err_msg in str(exc.value) + + def test_get_http_path__empty_compute(self, creds, err_msg): + context = QueryContextWrapper(compute_name="foo", relation_name="a_relation") + creds.compute = {"foo": {}} + with pytest.raises(DbtRuntimeError) as exc: + QueryConfigUtils.get_http_path(context, creds) + + assert err_msg in str(exc.value) + + def test_get_http_path__matching_compute(self, creds): + context = QueryContextWrapper(compute_name="foo", relation_name="a_relation") + creds.compute = {"foo": {"http_path": "alternate_path"}} + assert "alternate_path" == QueryConfigUtils.get_http_path(context, creds) + + def test_get_max_idle__no_config(self, creds): + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(), creds) + assert connections.DEFAULT_MAX_IDLE_TIME == time + + def test_get_max_idle__no_matching_compute(self, creds): + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds) + assert connections.DEFAULT_MAX_IDLE_TIME == time + + def test_get_max_idle__compute_without_details(self, creds): 
+ creds.compute = {"foo": {}} + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds) + assert connections.DEFAULT_MAX_IDLE_TIME == time + + def test_get_max_idle__creds_but_no_context(self, creds): + creds.connect_max_idle = 77 + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(), creds) + assert 77 == time + + def test_get_max_idle__matching_compute_no_value(self, creds): + creds.connect_max_idle = 77 + creds.compute = {"foo": {}} + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds) + assert 77 == time + + def test_get_max_idle__matching_compute(self, creds): + creds.compute = {"foo": {"connect_max_idle": "88"}} + creds.connect_max_idle = 77 + time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds) + assert 88 == time + + def test_get_max_idle__invalid_config(self, creds): + creds.compute = {"foo": {"connect_max_idle": "bar"}} + + with pytest.raises(DbtRuntimeError) as info: + QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds) + assert ( + "bar is not a valid value for connect_max_idle. Must be a number of seconds." + ) in str(info.value) From 35bffecbecaad5397d4a556b37f4e5665b1c266f Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Wed, 19 Feb 2025 09:06:25 -0800 Subject: [PATCH 21/21] Lazy init api client for connection manager to fix 940 (#941) --- CHANGELOG.md | 4 ++++ dbt/adapters/databricks/connections.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6cefd9c..a88c5413 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.9.6 (TBD) +### Fixes + +- Fix for parse raising error for not having credentials ([941](https://github.com/databricks/dbt-databricks/pull/941)) + ### Under the Hood - Refactoring of some connection internals ([929](https://github.com/databricks/dbt-databricks/pull/929)) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 56bf196c..ccf32c98 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -202,12 +202,19 @@ class DatabricksConnectionManager(SparkConnectionManager): def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) - creds = cast(DatabricksCredentials, self.profile.credentials) - self.api_client = DatabricksApiClient.create(creds, 15 * 60) + self._api_client: Optional[DatabricksApiClient] = None self.threads_compute_connections: dict[ Hashable, dict[Hashable, DatabricksDBTConnection] ] = {} + @property + def api_client(self) -> DatabricksApiClient: + if self._api_client is None: + self._api_client = DatabricksApiClient.create( + cast(DatabricksCredentials, self.profile.credentials), 15 * 60 + ) + return self._api_client + def cancel_open(self) -> list[str]: cancelled = super().cancel_open() logger.info("Cancelling open python jobs")
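
A minimal usage sketch of the refactored config helpers introduced by this series (illustrative only, not part of the patch): the node object below is a hypothetical stand-in for the query_header_context dbt passes into set_connection_name, and the relation and compute names are made up. It shows how QueryContextWrapper.from_context narrows the untyped context down to the two fields the adapter actually needs, and how QueryConfigUtils resolves the per-compute http_path and connect_max_idle from the credentials.

from types import SimpleNamespace

from dbt.adapters.databricks.connections import QueryConfigUtils, QueryContextWrapper
from dbt.adapters.databricks.credentials import DatabricksCredentials

# Hypothetical stand-in for dbt's query_header_context; from_context only reads
# .relation_name and .config.
node = SimpleNamespace(
    relation_name="catalog.schema.my_model",
    config={"databricks_compute": "heavy_warehouse"},
)

creds = DatabricksCredentials(http_path="default_http_path")
creds.compute = {"heavy_warehouse": {"http_path": "alternate_path", "connect_max_idle": "120"}}

wrapped = QueryContextWrapper.from_context(node)
assert wrapped.compute_name == "heavy_warehouse"
assert wrapped.relation_name == "catalog.schema.my_model"

# The named compute overrides the default http_path from the credentials...
assert QueryConfigUtils.get_http_path(wrapped, creds) == "alternate_path"
# ...and numeric strings for connect_max_idle are accepted and converted to seconds.
assert QueryConfigUtils.get_max_idle_time(wrapped, creds) == 120

With an empty QueryContextWrapper() (no compute specified), the same calls fall back to creds.http_path and DEFAULT_MAX_IDLE_TIME, which is what the new tests in test_query_config.py exercise.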
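
The final patch defers construction of the API client. As a sketch of that lazy-initialization pattern (illustrative only; FakeClient and Manager below are hypothetical stand-ins, not the adapter's classes): building the client only inside the property means merely constructing the connection manager, as happens during dbt parse, no longer requires valid credentials, while later callers of api_client still get a single cached instance.

from typing import Optional


class FakeClient:
    """Stands in for an expensive, credential-dependent client."""

    created = 0

    def __init__(self) -> None:
        FakeClient.created += 1


class Manager:
    def __init__(self) -> None:
        # Nothing credential-dependent happens at construction time.
        self._client: Optional[FakeClient] = None

    @property
    def client(self) -> FakeClient:
        if self._client is None:
            self._client = FakeClient()  # deferred until actually needed
        return self._client


m = Manager()
assert FakeClient.created == 0   # construction alone builds no client
_ = m.client
assert FakeClient.created == 1   # built lazily on first access
_ = m.client
assert FakeClient.created == 1   # and reused afterwards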