diff --git a/CHANGELOG.md b/CHANGELOG.md
index a4c107a99..a88c5413f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,53 @@
+## dbt-databricks 1.9.6 (TBD)
+
+### Fixes
+
+- Fix parse raising an error when credentials are missing ([941](https://github.com/databricks/dbt-databricks/pull/941))
+
+### Under the Hood
+
+- Refactoring of some connection internals ([929](https://github.com/databricks/dbt-databricks/pull/929))
+
+## dbt-databricks 1.9.5 (Feb 13, 2025)
+
+### Features
+
+- Add `auto_liquid_cluster` config to enable Auto Liquid Clustering for Delta-based dbt models (thanks @ShaneMazur!) ([935](https://github.com/databricks/dbt-databricks/pull/935))
+- Prepare environments for Python models with serverless clusters ([938](https://github.com/databricks/dbt-databricks/pull/938))
+
+### Fixes
+
+- table_format: iceberg is unblocked for snapshots ([930](https://github.com/databricks/dbt-databricks/pull/930))
+- Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934))
+- Use POSIX standard when creating location for the tables (thanks @gsolasab!) ([919](https://github.com/databricks/dbt-databricks/pull/919))
+
+### Under the Hood
+
+- Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910))
+- Clean up cursor management in the hopes of limiting issues with cancellation ([912](https://github.com/databricks/dbt-databricks/pull/912))
+
+## dbt-databricks 1.9.4 (Jan 30, 2025)
+
+### Under the Hood
+
+- Pinned the python sql connector to 3.6.0 as a temporary measure while we investigate failure to wait for cluster start
+
+## dbt-databricks 1.9.3
+
+Yanked due to being published with the incorrect bits
+
+## dbt-databricks 1.9.2 (Jan 21, 2025)
+
+### Features
+
+- Update snapshot materialization to support new snapshot features ([904](https://github.com/databricks/dbt-databricks/pull/904))
+
+### Under the Hood
+
+- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888))
+- Switch to relation.render() for string interpolation ([903](https://github.com/databricks/dbt-databricks/pull/903))
+- Ensure retry defaults for PySQL ([907](https://github.com/databricks/dbt-databricks/pull/907))
+
 ## dbt-databricks 1.9.1 (December 16, 2024)
 
 ### Features
diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py
index 702279763..14e6fa9f7 100644
--- a/dbt/adapters/databricks/__version__.py
+++ b/dbt/adapters/databricks/__version__.py
@@ -1 +1 @@
-version = "1.9.1"
+version = "1.9.5"
diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
index 57bc4ca3c..251da5a63 100644
--- a/dbt/adapters/databricks/api_client.py
+++ b/dbt/adapters/databricks/api_client.py
@@ -326,6 +326,10 @@ def __init__(self, session: Session, host: str, polling_interval: int, timeout:
     def submit(
         self, run_name: str, job_spec: dict[str, Any], **additional_job_settings: dict[str, Any]
     ) -> str:
+        logger.debug(
+            f"Submitting job with run_name={run_name} and job_spec={job_spec}"
+            f" and additional_job_settings={additional_job_settings}"
+        )
         submit_response = self.session.post(
             "/submit", json={"run_name": run_name, "tasks": [job_spec], **additional_job_settings}
         )
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 509686d76..ccf32c981 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -1,16
+1,9 @@ -import decimal -import os import re -import sys import time -import uuid -import warnings -from collections.abc import Callable, Hashable, Iterator, Sequence +from collections.abc import Callable, Hashable, Iterator from contextlib import contextmanager from dataclasses import dataclass from multiprocessing.context import SpawnContext -from numbers import Number -from threading import get_ident from typing import TYPE_CHECKING, Any, Optional, cast from dbt_common.events.contextvars import get_node_info @@ -18,9 +11,7 @@ from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError from dbt_common.utils import cast_to_str -import databricks.sql as dbsql -from databricks.sql.client import Connection as DatabricksSQLConnection -from databricks.sql.client import Cursor as DatabricksSQLCursor +from databricks.sql import __version__ as dbsql_version from databricks.sql.exc import Error from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.adapters.contracts.connection import ( @@ -34,15 +25,13 @@ ) from dbt.adapters.databricks.__version__ import version as __version__ from dbt.adapters.databricks.api_client import DatabricksApiClient -from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider +from dbt.adapters.databricks.credentials import ( + DatabricksCredentials, + TCredentialProvider, +) from dbt.adapters.databricks.events.connection_events import ( ConnectionAcquire, - ConnectionCancel, - ConnectionCancelError, - ConnectionClose, - ConnectionCloseError, ConnectionCreate, - ConnectionCreated, ConnectionCreateError, ConnectionIdleCheck, ConnectionIdleClose, @@ -51,14 +40,8 @@ ConnectionRetrieve, ConnectionReuse, ) -from dbt.adapters.databricks.events.cursor_events import ( - CursorCancel, - CursorCancelError, - CursorClose, - CursorCloseError, - CursorCreate, -) from dbt.adapters.databricks.events.other_events import QueryError +from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -83,197 +66,19 @@ ) -DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") - - -# toggle for session managements that minimizes the number of sessions opened/closed -USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" - # Number of idle seconds before a connection is automatically closed. Only applicable if # USE_LONG_SESSIONS is true. 
# Updated when idle times of 180s were causing errors DEFAULT_MAX_IDLE_TIME = 60 -class DatabricksSQLConnectionWrapper: - """Wrap a Databricks SQL connector in a way that no-ops transactions""" - - _conn: DatabricksSQLConnection - _is_cluster: bool - _cursors: list[DatabricksSQLCursor] - _creds: DatabricksCredentials - _user_agent: str - - def __init__( - self, - conn: DatabricksSQLConnection, - *, - is_cluster: bool, - creds: DatabricksCredentials, - user_agent: str, - ): - self._conn = conn - self._is_cluster = is_cluster - self._cursors = [] - self._creds = creds - self._user_agent = user_agent - - def cursor(self) -> "DatabricksSQLCursorWrapper": - cursor = self._conn.cursor() - - logger.debug(CursorCreate(cursor)) - - self._cursors.append(cursor) - return DatabricksSQLCursorWrapper( - cursor, - creds=self._creds, - user_agent=self._user_agent, - ) - - def cancel(self) -> None: - logger.debug(ConnectionCancel(self._conn)) - - cursors: list[DatabricksSQLCursor] = self._cursors - - for cursor in cursors: - try: - cursor.cancel() - except Error as exc: - logger.warning(ConnectionCancelError(self._conn, exc)) - - def close(self) -> None: - logger.debug(ConnectionClose(self._conn)) - - try: - self._conn.close() - except Error as exc: - logger.warning(ConnectionCloseError(self._conn, exc)) - - def rollback(self, *args: Any, **kwargs: Any) -> None: - logger.debug("NotImplemented: rollback") - - _dbr_version: tuple[int, int] - - @property - def dbr_version(self) -> tuple[int, int]: - if not hasattr(self, "_dbr_version"): - if self._is_cluster: - with self._conn.cursor() as cursor: - cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion") - results = cursor.fetchone() - if results: - dbr_version: str = results[1] - - m = DBR_VERSION_REGEX.search(dbr_version) - assert m, f"Unknown DBR version: {dbr_version}" - major = int(m.group(1)) - try: - minor = int(m.group(2)) - except ValueError: - minor = sys.maxsize - self._dbr_version = (major, minor) - else: - # Assuming SQL Warehouse uses the latest version. 
- self._dbr_version = (sys.maxsize, sys.maxsize) - - return self._dbr_version - - -class DatabricksSQLCursorWrapper: - """Wrap a Databricks SQL cursor in a way that no-ops transactions""" - - _cursor: DatabricksSQLCursor - _user_agent: str - _creds: DatabricksCredentials - - def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str): - self._cursor = cursor - self._creds = creds - self._user_agent = user_agent - - def cancel(self) -> None: - logger.debug(CursorCancel(self._cursor)) - - try: - self._cursor.cancel() - except Error as exc: - logger.warning(CursorCancelError(self._cursor, exc)) - - def close(self) -> None: - logger.debug(CursorClose(self._cursor)) - - try: - self._cursor.close() - except Error as exc: - logger.warning(CursorCloseError(self._cursor, exc)) - - def fetchall(self) -> Sequence[tuple]: - return self._cursor.fetchall() - - def fetchone(self) -> Optional[tuple]: - return self._cursor.fetchone() - - def fetchmany(self, size: int) -> Sequence[tuple]: - return self._cursor.fetchmany(size) - - def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: - # print(f"execute: {sql}") - if sql.strip().endswith(";"): - sql = sql.strip()[:-1] - if bindings is not None: - bindings = [self._fix_binding(binding) for binding in bindings] - self._cursor.execute(sql, bindings) - - @property - def hex_query_id(self) -> str: - """Return the hex GUID for this query - - This UUID can be tied back to the Databricks query history API - """ - if self._cursor.active_result_set: - _as_hex = uuid.UUID(bytes=self._cursor.active_result_set.command_id.operationId.guid) - return str(_as_hex) - return "" - - @classmethod - def _fix_binding(cls, value: Any) -> Any: - """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - - @property - def description(self) -> Optional[list[tuple]]: - return self._cursor.description - - def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None: - self._cursor.schemas(catalog_name=catalog_name, schema_name=schema_name) - - def tables(self, catalog_name: str, schema_name: str, table_name: Optional[str] = None) -> None: - self._cursor.tables( - catalog_name=catalog_name, schema_name=schema_name, table_name=table_name - ) - - def __del__(self) -> None: - if self._cursor.open: - # This should not happen. The cursor should explicitly be closed. 
- logger.debug(CursorClose(self._cursor)) - - self._cursor.close() - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn("The cursor was closed by destructor.") - - DATABRICKS_QUERY_COMMENT = f""" {{%- set comment_dict = {{}} -%}} {{%- do comment_dict.update( app='dbt', dbt_version=dbt_version, dbt_databricks_version='{__version__}', - databricks_sql_connector_version='{dbsql.__version__}', + databricks_sql_connector_version='{dbsql_version}', profile_name=target.get('profile_name'), target_name=target.get('target_name'), ) -%}} @@ -289,6 +94,26 @@ def __del__(self) -> None: """ +@dataclass(frozen=True) +class QueryContextWrapper: + """ + Until dbt tightens this protocol up, we need to wrap the context for safety + """ + + compute_name: Optional[str] = None + relation_name: Optional[str] = None + + @staticmethod + def from_context(query_header_context: Any) -> "QueryContextWrapper": + if query_header_context is None: + return QueryContextWrapper() + compute_name = None + relation_name = getattr(query_header_context, "relation_name", "[unknown]") + if hasattr(query_header_context, "config") and query_header_context.config: + compute_name = query_header_context.config.get("databricks_compute") + return QueryContextWrapper(compute_name=compute_name, relation_name=relation_name) + + class DatabricksMacroQueryStringSetter(MacroQueryStringSetter): def _get_comment_macro(self) -> Optional[str]: if self.config.query_comment.comment == DEFAULT_QUERY_COMMENT: @@ -297,11 +122,6 @@ def _get_comment_macro(self) -> Optional[str]: return self.config.query_comment.comment -@dataclass -class DatabricksAdapterResponse(AdapterResponse): - query_id: str = "" - - @dataclass(init=False) class DatabricksDBTConnection(Connection): last_used_time: Optional[float] = None @@ -379,12 +199,21 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None: class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_provider: Optional[TCredentialProvider] = None - _user_agent = f"dbt-databricks/{__version__}" def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) - creds = cast(DatabricksCredentials, self.profile.credentials) - self.api_client = DatabricksApiClient.create(creds, 15 * 60) + self._api_client: Optional[DatabricksApiClient] = None + self.threads_compute_connections: dict[ + Hashable, dict[Hashable, DatabricksDBTConnection] + ] = {} + + @property + def api_client(self) -> DatabricksApiClient: + if self._api_client is None: + self._api_client = DatabricksApiClient.create( + cast(DatabricksCredentials, self.profile.credentials), 15 * 60 + ) + return self._api_client def cancel_open(self) -> list[str]: cancelled = super().cancel_open() @@ -395,8 +224,8 @@ def cancel_open(self) -> list[str]: def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) - connection: DatabricksSQLConnectionWrapper = self.get_thread_connection().handle - dbr_version = connection.dbr_version + handle: DatabricksHandle = self.get_thread_connection().handle + dbr_version = handle.dbr_version return (dbr_version > version) - (dbr_version < version) def set_query_header(self, query_header_context: dict[str, Any]) -> None: @@ -433,39 +262,19 @@ def set_connection_name( 'connection_named', called by 'connection_for(node)'. 
Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" + self._cleanup_idle_connections() conn_name: str = "master" if name is None else name - + wrapped = QueryContextWrapper.from_context(query_header_context) # Get a connection for this thread - conn = self.get_if_exists() - - if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn + conn = self._get_if_exists_compute_connection(wrapped.compute_name or "") if conn is None: - # Create a new connection - conn = DatabricksDBTConnection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - # Add the connection to thread_connections for this thread - self.set_thread_connection(conn) - fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) - ) + conn = self._create_compute_connection(conn_name, wrapped) else: # existing connection either wasn't open or didn't have the right name - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - if conn.name != conn_name: - orig_conn_name: str = conn.name or "" - conn.name = conn_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + conn = self._update_compute_connection(conn, conn_name) + + conn._acquire(wrapped) return conn @@ -475,6 +284,8 @@ def add_query( auto_begin: bool = True, bindings: Optional[Any] = None, abridge_sql_log: bool = False, + retryable_exceptions: tuple[type[Exception], ...] 
= tuple(), + retry_limit: int = 1, *, close_cursor: bool = False, ) -> tuple[Connection, Any]: @@ -484,7 +295,7 @@ def add_query( fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(sql): - cursor: Optional[DatabricksSQLCursorWrapper] = None + cursor: Optional[CursorWrapper] = None try: log_sql = redact_credentials(sql) if abridge_sql_log: @@ -500,12 +311,12 @@ def add_query( pre = time.time() - cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor() - cursor.execute(sql, bindings) + handle: DatabricksHandle = connection.handle + cursor = handle.execute(sql, bindings) fire_event( SQLQueryStatus( - status=str(self.get_response(cursor)), + status=str(cursor.get_response()), elapsed=round((time.time() - pre), 2), node_info=get_node_info(), ) @@ -513,9 +324,7 @@ def add_query( return connection, cursor except Error: - if cursor is not None: - cursor.close() - cursor = None + close_cursor = True raise finally: if close_cursor and cursor is not None: @@ -527,11 +336,11 @@ def execute( auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> tuple[DatabricksAdapterResponse, "Table"]: + ) -> tuple[AdapterResponse, "Table"]: sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) try: - response = self.get_response(cursor) + response = cursor.get_response() if fetch: table = self.get_result_from_cursor(cursor, limit) else: @@ -543,15 +352,15 @@ def execute( finally: cursor.close() - def _execute_cursor( - self, log_sql: str, f: Callable[[DatabricksSQLCursorWrapper], None] + def _execute_with_cursor( + self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper] ) -> "Table": connection = self.get_thread_connection() fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(log_sql): - cursor: Optional[DatabricksSQLCursorWrapper] = None + cursor: Optional[CursorWrapper] = None try: fire_event( SQLQuery( @@ -563,9 +372,8 @@ def _execute_cursor( pre = time.time() - handle: DatabricksSQLConnectionWrapper = connection.handle - cursor = handle.cursor() - f(cursor) + handle: DatabricksHandle = connection.handle + cursor = f(handle) fire_event( SQLQueryStatus( @@ -577,53 +385,58 @@ def _execute_cursor( return self.get_result_from_cursor(cursor, None) finally: - if cursor is not None: + if cursor: cursor.close() def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table": database = database.strip("`") if schema: schema = schema.strip("`").lower() - return self._execute_cursor( + return self._execute_with_cursor( f"GetSchemas(database={database}, schema={schema})", - lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema), + lambda cursor: cursor.list_schemas(database=database, schema=schema), ) - def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> "Table": + def list_tables(self, database: str, schema: str) -> "Table": database = database.strip("`") schema = schema.strip("`").lower() - if identifier: - identifier = identifier.strip("`") - return self._execute_cursor( - f"GetTables(database={database}, schema={schema}, identifier={identifier})", - lambda cursor: cursor.tables( - catalog_name=database, schema_name=schema, table_name=identifier - ), + return self._execute_with_cursor( + f"GetTables(database={database}, schema={schema})", + lambda cursor: cursor.list_tables(database=database, schema=schema), ) - @classmethod - def 
get_open_for_context( - cls, query_header_context: Any = None - ) -> Callable[[Connection], Connection]: - # If there is no node we can simply return the exsting class method open. - # If there is a node create a closure that will call cls._open with the node. - if not query_header_context: - return cls.open + # override + def release(self) -> None: + with self.lock: + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if conn is None: + return - def open_for_model(connection: Connection) -> Connection: - return cls._open(connection, query_header_context) + conn._release() - return open_for_model + # override + def cleanup_all(self) -> None: + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) + ) + else: + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) + ) + self.close(connection) + + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() @classmethod def open(cls, connection: Connection) -> Connection: - # Simply call _open with no ResultNode argument. - # Because this is an overridden method we can't just add - # a ResultNode parameter to open. - return cls._open(connection) + databricks_connection = cast(DatabricksDBTConnection, connection) - @classmethod - def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: if connection.state == ConnectionState.OPEN: return connection @@ -632,45 +445,23 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn # gotta keep this so we don't prompt users many times cls.credentials_provider = creds.authenticate(cls.credentials_provider) - - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" - - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + conn_args = SqlUtils.prepare_connection_arguments( + creds, cls.credentials_provider, databricks_connection.http_path ) - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. - http_path = _get_http_path(query_header_context, creds) - - def connect() -> DatabricksSQLConnectionWrapper: + def connect() -> DatabricksHandle: try: # TODO: what is the error when a user specifies a catalog they don't have access to - conn: DatabricksSQLConnection = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. 
- _user_agent_entry=user_agent_entry, - **connection_parameters, + conn = DatabricksHandle.from_connection_args( + conn_args, creds.cluster_id is not None ) - logger.debug(ConnectionCreated(str(conn))) + if conn: + databricks_connection.session_id = conn.session_id + databricks_connection.last_used_time = time.time() - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) + return conn + else: + raise DbtDatabaseError("Failed to create connection") except Error as exc: logger.error(ConnectionCreateError(exc)) raise @@ -692,60 +483,6 @@ def exponential_backoff(attempt: int) -> int: retry_timeout=(timeout if timeout is not None else exponential_backoff), ) - @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: - _query_id = getattr(cursor, "hex_query_id", None) - if cursor is None: - logger.debug("No cursor was provided. Query ID not available.") - query_id = "N/A" - else: - query_id = _query_id - message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore - - -class ExtendedSessionConnectionManager(DatabricksConnectionManager): - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None: - assert ( - USE_LONG_SESSIONS - ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" - super().__init__(profile, mp_context) - self.threads_compute_connections: dict[ - Hashable, dict[Hashable, DatabricksDBTConnection] - ] = {} - - def set_connection_name( - self, name: Optional[str] = None, query_header_context: Any = None - ) -> Connection: - """Called by 'acquire_connection' in DatabricksAdapter, which is called by - 'connection_named', called by 'connection_for(node)'. 
- Creates a connection for this thread if one doesn't already - exist, and will rename an existing connection.""" - self._cleanup_idle_connections() - - conn_name: str = "master" if name is None else name - - # Get a connection for this thread - conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") - - if conn is None: - conn = self._create_compute_connection(conn_name, query_header_context) - else: # existing connection either wasn't open or didn't have the right name - conn = self._update_compute_connection(conn, conn_name) - - conn._acquire(query_header_context) - - return conn - - # override - def release(self) -> None: - with self.lock: - conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if conn is None: - return - - conn._release() - # override @classmethod def close(cls, connection: Connection) -> Connection: @@ -756,46 +493,18 @@ def close(cls, connection: Connection) -> Connection: connection.state = ConnectionState.CLOSED return connection - # override - def cleanup_all(self) -> None: - with self.lock: - for thread_connections in self.threads_compute_connections.values(): - for connection in thread_connections.values(): - if connection.acquire_release_count > 0: - fire_event( - ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) - ) - else: - fire_event( - ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) - ) - self.close(connection) - - # garbage collect these connections - self.thread_connections.clear() - self.threads_compute_connections.clear() - - def _update_compute_connection( - self, conn: DatabricksDBTConnection, new_name: str - ) -> DatabricksDBTConnection: - if conn.name == new_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn - - orig_conn_name: str = conn.name or "" - - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.open) - if conn.name != new_name: - conn.name = new_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - - current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: - self.clear_thread_connection() - self.set_thread_connection(conn) + @classmethod + def get_response(cls, cursor: Any) -> AdapterResponse: + if isinstance(cursor, CursorWrapper): + return cursor.get_response() + else: + return AdapterResponse("OK") - logger.debug(ConnectionReuse(str(conn), orig_conn_name)) + def get_thread_connection(self) -> Connection: + conn = super().get_thread_connection() + self._cleanup_idle_connections() + dbr_conn = cast(DatabricksDBTConnection, conn) + logger.debug(ConnectionRetrieve(str(dbr_conn))) return conn @@ -810,28 +519,6 @@ def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: ) thread_map[conn.compute_name] = conn - def _get_compute_connections( - self, - ) -> dict[Hashable, DatabricksDBTConnection]: - """Retrieve a map of compute name to connection for the current thread.""" - - thread_id = self.get_thread_identifier() - with self.lock: - thread_map = self.threads_compute_connections.get(thread_id) - if not thread_map: - thread_map = {} - self.threads_compute_connections[thread_id] = thread_map - return thread_map - - def _get_if_exists_compute_connection( - self, compute_name: str - ) -> Optional[DatabricksDBTConnection]: - """Get the connection for the current thread and named compute, if it exists.""" - - 
with self.lock: - threads_map = self._get_compute_connections() - return threads_map.get(compute_name) - def _cleanup_idle_connections(self) -> None: with self.lock: # Get all connections associated with this thread. There can be multiple connections @@ -855,16 +542,16 @@ def _cleanup_idle_connections(self) -> None: ) and conn._idle_too_long(): logger.debug(ConnectionIdleClose(str(conn))) self.close(conn) - conn._reset_handle(self._open) + conn._reset_handle(self.open) def _create_compute_connection( - self, conn_name: str, query_header_context: Any = None + self, conn_name: str, query_header_context: QueryContextWrapper ) -> DatabricksDBTConnection: """Create anew connection for the combination of current thread and compute associated with the given node.""" # Create a new connection - compute_name = _get_compute_name(query_header_context) or "" + compute_name = query_header_context.compute_name or "" conn = DatabricksDBTConnection( type=Identifier(self.TYPE), @@ -876,9 +563,9 @@ def _create_compute_connection( ) conn.compute_name = compute_name creds = cast(DatabricksCredentials, self.profile.credentials) - conn.http_path = _get_http_path(query_header_context, creds=creds) or "" + conn.http_path = QueryConfigUtils.get_http_path(query_header_context, creds) conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) - conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds) + conn.max_idle_time = QueryConfigUtils.get_max_idle_time(query_header_context, creds) conn.handle = LazyHandle(self.open) @@ -897,176 +584,103 @@ def _create_compute_connection( return conn - def get_thread_connection(self) -> Connection: - conn = super().get_thread_connection() - self._cleanup_idle_connections() - dbr_conn = cast(DatabricksDBTConnection, conn) - logger.debug(ConnectionRetrieve(str(dbr_conn))) + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" - return conn + with self.lock: + threads_map = self._get_compute_connections() + return threads_map.get(compute_name) - @classmethod - def open(cls, connection: Connection) -> Connection: - # Once long session management is no longer under the USE_LONG_SESSIONS toggle - # this should be renamed and replace the _open class method. 
- assert ( - USE_LONG_SESSIONS - ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" + def _get_compute_connections( + self, + ) -> dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" - databricks_connection = cast(DatabricksDBTConnection, connection) + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map - if connection.state == ConnectionState.OPEN: - return connection + def _update_compute_connection( + self, conn: DatabricksDBTConnection, new_name: str + ) -> DatabricksDBTConnection: + if conn.name == new_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn - creds: DatabricksCredentials = connection.credentials - timeout = creds.connect_timeout + orig_conn_name: str = conn.name or "" - # gotta keep this so we don't prompt users many times - cls.credentials_provider = creds.authenticate(cls.credentials_provider) + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self.open) + if conn.name != new_name: + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + logger.debug(ConnectionReuse(str(conn), orig_conn_name)) - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() - ) + return conn - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. - http_path = databricks_connection.http_path - def connect() -> DatabricksSQLConnectionWrapper: - try: - # TODO: what is the error when a user specifies a catalog they don't have access to - conn = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. - _user_agent_entry=user_agent_entry, - **connection_parameters, - ) +class QueryConfigUtils: + """ + Utility class for getting config values from QueryHeaderContextWrapper and Credentials. + """ - if conn: - databricks_connection.session_id = conn.get_session_id_hex() - databricks_connection.last_used_time = time.time() - logger.debug(ConnectionCreated(str(databricks_connection))) - - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) - except Error as exc: - logger.error(ConnectionCreateError(exc)) - raise + @staticmethod + def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> str: + """ + Get the http_path for the compute specified for the node. 
+ If none is specified default will be used. + """ - def exponential_backoff(attempt: int) -> int: - return attempt * attempt + if not context.compute_name: + return creds.http_path or "" - retryable_exceptions = [] - # this option is for backwards compatibility - if creds.retry_all: - retryable_exceptions = [Error] + # Get the http_path for the named compute. + http_path = None + if creds.compute: + http_path = creds.compute.get(context.compute_name, {}).get("http_path", None) - return cls.retry_connection( - connection, - connect=connect, - logger=logger, - retryable_exceptions=retryable_exceptions, - retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), - ) + # no http_path for the named compute resource is an error condition + if not http_path: + raise DbtRuntimeError( + f"Compute resource {context.compute_name} does not exist or " + f"does not specify http_path, relation: {context.relation_name}" + ) + return http_path -def _get_compute_name(query_header_context: Any) -> Optional[str]: - # Get the name of the specified compute resource from the node's - # config. - compute_name = None - if ( - query_header_context - and hasattr(query_header_context, "config") - and query_header_context.config - ): - compute_name = query_header_context.config.get("databricks_compute", None) - return compute_name - - -def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> Optional[str]: - """Get the http_path for the compute specified for the node. - If none is specified default will be used.""" - - thread_id = (os.getpid(), get_ident()) - - # ResultNode *should* have relation_name attr, but we work around a core - # issue by checking. - relation_name = getattr(query_header_context, "relation_name", "[unknown]") - - # If there is no node we return the http_path for the default compute. - if not query_header_context: - if not USE_LONG_SESSIONS: - logger.debug(f"Thread {thread_id}: using default compute resource.") - return creds.http_path - - # Get the name of the compute resource specified in the node's config. - # If none is specified return the http_path for the default compute. - compute_name = _get_compute_name(query_header_context) - if not compute_name: - if not USE_LONG_SESSIONS: - logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") - return creds.http_path - - # Get the http_path for the named compute. - http_path = None - if creds.compute: - http_path = creds.compute.get(compute_name, {}).get("http_path", None) - - # no http_path for the named compute resource is an error condition - if not http_path: - raise DbtRuntimeError( - f"Compute resource {compute_name} does not exist or " - f"does not specify http_path, relation: {relation_name}" - ) + @staticmethod + def get_max_idle_time(context: QueryContextWrapper, creds: DatabricksCredentials) -> int: + """Get the http_path for the compute specified for the node. + If none is specified default will be used.""" - if not USE_LONG_SESSIONS: - logger.debug( - f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." + max_idle_time = ( + DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle ) - return http_path - - -def _get_max_idle_time(query_header_context: Any, creds: DatabricksCredentials) -> int: - """Get the http_path for the compute specified for the node. 
- If none is specified default will be used.""" - - max_idle_time = ( - DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle - ) - - if query_header_context: - compute_name = _get_compute_name(query_header_context) - if compute_name and creds.compute: - max_idle_time = creds.compute.get(compute_name, {}).get( + if context.compute_name and creds.compute: + max_idle_time = creds.compute.get(context.compute_name, {}).get( "connect_max_idle", max_idle_time ) - if not isinstance(max_idle_time, Number): - if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): - return int(max_idle_time.strip()) - else: - raise DbtRuntimeError( - f"{max_idle_time} is not a valid value for connect_max_idle. " - "Must be a number of seconds." - ) + if not isinstance(max_idle_time, int): + if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): + return int(max_idle_time.strip()) + else: + raise DbtRuntimeError( + f"{max_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds." + ) - return max_idle_time + return max_idle_time diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 7a318cada..250e79f65 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -19,10 +19,10 @@ CredentialSaveError, CredentialShardEvent, ) +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" -DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" @@ -74,8 +74,10 @@ class DatabricksCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) - if "database" not in data: - data["database"] = None + data.setdefault("database", None) + data.setdefault("connection_parameters", {}) + data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30) + data["connection_parameters"].setdefault("_retry_delay_max", 60) return data def __post_init__(self) -> None: @@ -150,7 +152,7 @@ def validate_creds(self) -> None: @classmethod def get_invocation_env(cls) -> Optional[str]: - invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV) + invocation_env = GlobalState.get_invocation_env() if invocation_env: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. 
@@ -160,9 +162,7 @@ def get_invocation_env(cls) -> Optional[str]: @classmethod def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: - http_session_headers_str: Optional[str] = os.environ.get( - DBT_DATABRICKS_HTTP_SESSION_HEADERS - ) + http_session_headers_str = GlobalState.get_http_session_headers() http_session_headers_dict: dict[str, str] = ( { diff --git a/dbt/adapters/databricks/events/connection_events.py b/dbt/adapters/databricks/events/connection_events.py index 9f8ec8c1b..3f533c1d1 100644 --- a/dbt/adapters/databricks/events/connection_events.py +++ b/dbt/adapters/databricks/events/connection_events.py @@ -17,30 +17,6 @@ def __str__(self) -> str: return f"Connection(session-id={self.session_id}) - {self.message}" -class ConnectionCancel(ConnectionEvent): - def __init__(self, connection: Optional[Connection]): - super().__init__(connection, "Cancelling connection") - - -class ConnectionClose(ConnectionEvent): - def __init__(self, connection: Optional[Connection]): - super().__init__(connection, "Closing connection") - - -class ConnectionCancelError(ConnectionEvent): - def __init__(self, connection: Optional[Connection], exception: Exception): - super().__init__( - connection, str(SQLErrorEvent(exception, "Exception while trying to cancel connection")) - ) - - -class ConnectionCloseError(ConnectionEvent): - def __init__(self, connection: Optional[Connection], exception: Exception): - super().__init__( - connection, str(SQLErrorEvent(exception, "Exception while trying to close connection")) - ) - - class ConnectionCreateError(ConnectionEvent): def __init__(self, exception: Exception): super().__init__( diff --git a/dbt/adapters/databricks/events/cursor_events.py b/dbt/adapters/databricks/events/cursor_events.py deleted file mode 100644 index d94a002a2..000000000 --- a/dbt/adapters/databricks/events/cursor_events.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABC -from uuid import UUID - -from databricks.sql.client import Cursor - -from dbt.adapters.databricks.events.base import SQLErrorEvent - - -class CursorEvent(ABC): - def __init__(self, cursor: Cursor, message: str): - self.message = message - self.session_id = "Unknown" - self.command_id = "Unknown" - if cursor: - if cursor.connection: - self.session_id = cursor.connection.get_session_id_hex() - if ( - cursor.active_result_set - and cursor.active_result_set.command_id - and cursor.active_result_set.command_id.operationId - ): - self.command_id = ( - str(UUID(bytes=cursor.active_result_set.command_id.operationId.guid)) - or "Unknown" - ) - - def __str__(self) -> str: - return ( - f"Cursor(session-id={self.session_id}, command-id={self.command_id}) - {self.message}" - ) - - -class CursorCloseError(CursorEvent): - def __init__(self, cursor: Cursor, exception: Exception): - super().__init__( - cursor, str(SQLErrorEvent(exception, "Exception while trying to close cursor")) - ) - - -class CursorCancelError(CursorEvent): - def __init__(self, cursor: Cursor, exception: Exception): - super().__init__( - cursor, str(SQLErrorEvent(exception, "Exception while trying to cancel cursor")) - ) - - -class CursorCreate(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Created cursor") - - -class CursorClose(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Closing cursor") - - -class CursorCancel(CursorEvent): - def __init__(self, cursor: Cursor): - super().__init__(cursor, "Cancelling cursor") diff --git 
a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py new file mode 100644 index 000000000..e01bb0ce0 --- /dev/null +++ b/dbt/adapters/databricks/global_state.py @@ -0,0 +1,52 @@ +import os +from typing import ClassVar, Optional + +from dbt.adapters.databricks.__version__ import version as __version__ + + +class GlobalState: + """Global state is a bad idea, but since we don't control instantiation, better to have it in a + single place than scattered throughout the codebase. + """ + + __invocation_env: ClassVar[Optional[str]] = None + __invocation_env_set: ClassVar[bool] = False + + USER_AGENT = f"dbt-databricks/{__version__}" + + @classmethod + def get_invocation_env(cls) -> Optional[str]: + if not cls.__invocation_env_set: + cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV") + cls.__invocation_env_set = True + return cls.__invocation_env + + __session_headers: ClassVar[Optional[str]] = None + __session_headers_set: ClassVar[bool] = False + + @classmethod + def get_http_session_headers(cls) -> Optional[str]: + if not cls.__session_headers_set: + cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS") + cls.__session_headers_set = True + return cls.__session_headers + + __describe_char_bypass: ClassVar[Optional[bool]] = None + + @classmethod + def get_char_limit_bypass(cls) -> bool: + if cls.__describe_char_bypass is None: + cls.__describe_char_bypass = ( + os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" + ) + return cls.__describe_char_bypass + + __connector_log_level: ClassVar[Optional[str]] = None + + @classmethod + def get_connector_log_level(cls) -> str: + if cls.__connector_log_level is None: + cls.__connector_log_level = os.getenv( + "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN" + ).upper() + return cls.__connector_log_level diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py new file mode 100644 index 000000000..1b7d30dbe --- /dev/null +++ b/dbt/adapters/databricks/handle.py @@ -0,0 +1,324 @@ +import decimal +import re +import sys +from collections.abc import Callable, Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +from dbt_common.exceptions import DbtRuntimeError + +import databricks.sql as dbsql +from databricks.sql.client import Connection, Cursor +from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.databricks import utils +from dbt.adapters.databricks.__version__ import version as __version__ +from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider +from dbt.adapters.databricks.logging import logger + +if TYPE_CHECKING: + pass + + +CursorOp = Callable[[Cursor], None] +CursorExecOp = Callable[[Cursor], Cursor] +CursorWrapperOp = Callable[["CursorWrapper"], None] +ConnectionOp = Callable[[Optional[Connection]], None] +LogOp = Callable[[], str] +FailLogOp = Callable[[Exception], str] + + +class CursorWrapper: + """ + Wrap the DBSQL cursor to abstract the details from DatabricksConnectionManager. 
+ """ + + def __init__(self, cursor: Cursor): + self._cursor = cursor + self.open = True + + @property + def description(self) -> Optional[list[tuple]]: + return self._cursor.description + + def cancel(self) -> None: + if self._cursor.active_op_handle: + self._cleanup( + lambda cursor: cursor.cancel(), + lambda: f"{self} - Cancelling", + lambda ex: f"{self} - Exception while cancelling: {ex}", + ) + + def close(self) -> None: + self._cleanup( + lambda cursor: cursor.close(), + lambda: f"{self} - Closing", + lambda ex: f"{self} - Exception while closing: {ex}", + ) + + def _cleanup( + self, + cleanup: CursorOp, + startLog: LogOp, + failLog: FailLogOp, + ) -> None: + """ + Common cleanup function for cursor operations, handling either close or cancel. + """ + if self.open: + self.open = False + logger.debug(startLog()) + utils.handle_exceptions_as_warning(lambda: cleanup(self._cursor), failLog) + + def fetchall(self) -> Sequence[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchall()) + + def fetchone(self) -> Optional[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchone()) + + def fetchmany(self, size: int) -> Sequence[tuple]: + return self._safe_execute(lambda cursor: cursor.fetchmany(size)) + + def get_response(self) -> AdapterResponse: + return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A") + + T = TypeVar("T") + + def _safe_execute(self, f: Callable[[Cursor], T]) -> T: + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed cursor") + return f(self._cursor) + + def __str__(self) -> str: + return ( + f"Cursor(session-id={self._cursor.connection.get_session_id_hex()}, " + f"command-id={self._cursor.query_id})" + ) + + def __enter__(self) -> "CursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + + +class DatabricksHandle: + """ + Handle for a Databricks SQL Session. + Provides a layer of abstraction over the Databricks SQL client library such that + DatabricksConnectionManager does not depend on the details of this library directly. + """ + + def __init__( + self, + conn: Connection, + is_cluster: bool, + ): + self._conn = conn + self.open = True + self._cursor: Optional[CursorWrapper] = None + self._dbr_version: Optional[tuple[int, int]] = None + self._is_cluster = is_cluster + + @property + def dbr_version(self) -> tuple[int, int]: + """ + Gets the DBR version of the current session. + """ + if not self._dbr_version: + if self._is_cluster: + cursor = self._safe_execute( + lambda cursor: cursor.execute( + "SET spark.databricks.clusterUsageTags.sparkVersion" + ) + ) + results = cursor.fetchone() + self._dbr_version = SqlUtils.extract_dbr_version(results[1] if results else "") + cursor.close() + else: + # Assuming SQL Warehouse uses the latest version. + self._dbr_version = (sys.maxsize, sys.maxsize) + + return self._dbr_version + + @property + def session_id(self) -> str: + return self._conn.get_session_id_hex() + + def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> CursorWrapper: + """ + Execute a SQL statement on the current session with optional bindings. 
+ """ + return self._safe_execute( + lambda cursor: cursor.execute( + SqlUtils.clean_sql(sql), SqlUtils.translate_bindings(bindings) + ) + ) + + def list_schemas(self, database: str, schema: Optional[str] = None) -> CursorWrapper: + """ + Get a cursor for listing schemas in the given database. + """ + return self._safe_execute(lambda cursor: cursor.schemas(database, schema)) + + def list_tables(self, database: str, schema: str) -> "CursorWrapper": + """ + Get a cursor for listing tables in the given database and schema. + """ + + return self._safe_execute(lambda cursor: cursor.tables(database, schema)) + + def cancel(self) -> None: + """ + Cancel in progress query, if any, then close connection and cursor. + """ + self._cleanup( + lambda cursor: cursor.cancel(), + lambda: f"{self} - Cancelling", + lambda ex: f"{self} - Exception while cancelling: {ex}", + ) + + def close(self) -> None: + """ + Close the connection and cursor. + """ + + self._cleanup( + lambda cursor: cursor.close(), + lambda: f"{self} - Closing", + lambda ex: f"{self} - Exception while closing: {ex}", + ) + + def rollback(self) -> None: + """ + Required for interface compatibility, but not implemented. + """ + logger.debug("NotImplemented: rollback") + + @staticmethod + def from_connection_args( + conn_args: dict[str, Any], is_cluster: bool + ) -> Optional["DatabricksHandle"]: + """ + Create a new DatabricksHandle from the given connection arguments. + """ + + conn = dbsql.connect(**conn_args) + if not conn: + logger.warning(f"Failed to create connection for {conn_args.get('http_path')}") + return None + connection = DatabricksHandle(conn, is_cluster=is_cluster) + logger.debug(f"{connection} - Created") + + return connection + + def _cleanup( + self, + cursor_op: CursorWrapperOp, + startLog: LogOp, + failLog: FailLogOp, + ) -> None: + """ + Function for cleaning up the connection and cursor, handling either close or cancel. + """ + if self.open: + self.open = False + logger.debug(startLog()) + + if self._cursor: + cursor_op(self._cursor) + + utils.handle_exceptions_as_warning(lambda: self._conn.close(), failLog) + + def _safe_execute(self, f: CursorExecOp) -> CursorWrapper: + """ + Ensure that a previously opened cursor is closed and that a new one is created + before executing the given function. + Also ensures that we do not continue to execute SQL after a connection cleanup + has been requested. + """ + + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed connection") + assert self._conn, "Should not be possible for _conn to be None if open" + if self._cursor: + self._cursor.close() + self._cursor = CursorWrapper(f(self._conn.cursor())) + return self._cursor + + def __del__(self) -> None: + if self._cursor: + self._cursor.close() + + self.close() + + def __str__(self) -> str: + return f"Connection(session-id={self.session_id})" + + +class SqlUtils: + """ + Utility class for preparing cursor input/output. 
+ """ + + DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") + user_agent = f"dbt-databricks/{__version__}" + + @staticmethod + def extract_dbr_version(version: str) -> tuple[int, int]: + m = SqlUtils.DBR_VERSION_REGEX.search(version) + if m: + major = int(m.group(1)) + if m.group(2) == "x": + minor = sys.maxsize + else: + minor = int(m.group(2)) + return (major, minor) + else: + raise DbtRuntimeError("Failed to detect DBR version") + + @staticmethod + def translate_bindings(bindings: Optional[Sequence[Any]]) -> Optional[Sequence[Any]]: + if bindings: + return list(map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x, bindings)) + return None + + @staticmethod + def clean_sql(sql: str) -> str: + cleaned = sql.strip() + if cleaned.endswith(";"): + cleaned = cleaned[:-1] + return cleaned + + @staticmethod + def prepare_connection_arguments( + creds: DatabricksCredentials, creds_provider: TCredentialProvider, http_path: str + ) -> dict[str, Any]: + invocation_env = creds.get_invocation_env() + user_agent_entry = SqlUtils.user_agent + if invocation_env: + user_agent_entry += f"; {invocation_env}" + + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + + http_headers: list[tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) + + return { + "server_hostname": creds.host, + "http_path": http_path, + "credentials_provider": creds_provider, + "http_headers": http_headers if http_headers else None, + "session_configuration": creds.session_properties, + "catalog": creds.database, + "use_inline_params": "silent", + "schema": creds.schema, + "_user_agent_entry": user_agent_entry, + **connection_parameters, + } diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index dce432c96..24dd4d279 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -1,4 +1,4 @@ -import os +import posixpath import re from abc import ABC, abstractmethod from collections import defaultdict @@ -31,11 +31,8 @@ GetColumnsByInformationSchema, ) from dbt.adapters.databricks.column import DatabricksColumn -from dbt.adapters.databricks.connections import ( - USE_LONG_SESSIONS, - DatabricksConnectionManager, - ExtendedSessionConnectionManager, -) +from dbt.adapters.databricks.connections import DatabricksConnectionManager +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, @@ -113,6 +110,7 @@ class DatabricksConfig(AdapterConfig): partition_by: Optional[Union[list[str], str]] = None clustered_by: Optional[Union[list[str], str]] = None liquid_clustered_by: Optional[Union[list[str], str]] = None + auto_liquid_cluster: Optional[bool] = None buckets: Optional[int] = None options: Optional[dict[str, str]] = None merge_update_columns: Optional[str] = None @@ -142,8 +140,8 @@ def get_identifier_list_string(table_names: set[str]) -> str: """ _identifier = "|".join(table_names) - bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false") - if bypass_2048_char_limit == "true": + bypass_2048_char_limit = GlobalState.get_char_limit_bypass() + if bypass_2048_char_limit: _identifier = _identifier if len(_identifier) < 2048 else "*" return _identifier @@ -154,10 +152,7 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - if USE_LONG_SESSIONS: - 
ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager - else: - ConnectionManager = DatabricksConnectionManager + ConnectionManager = DatabricksConnectionManager connections: DatabricksConnectionManager @@ -200,9 +195,10 @@ def update_tblproperties_for_iceberg( raise DbtConfigError( "When table_format is 'iceberg', cannot set file_format to other than delta." ) - if config.get("materialized") not in ("incremental", "table"): + if config.get("materialized") not in ("incremental", "table", "snapshot"): raise DbtConfigError( - "When table_format is 'iceberg', materialized must be 'incremental' or 'table'." + "When table_format is 'iceberg', materialized must be 'incremental'" + ", 'table', or 'snapshot'." ) result["delta.enableIcebergCompatV2"] = "true" result["delta.universalFormat.enabledFormats"] = "iceberg" @@ -220,9 +216,9 @@ def compute_external_path( raise DbtConfigError("location_root is required for external tables.") include_full_name_in_path = config.get("include_full_name_in_path", False) if include_full_name_in_path: - path = os.path.join(location_root, database, schema, identifier) + path = posixpath.join(location_root, database, schema, identifier) else: - path = os.path.join(location_root, identifier) + path = posixpath.join(location_root, identifier) if is_incremental: path = path + "_tmp" return path @@ -631,9 +627,9 @@ def add_query( def run_sql_for_tests( self, sql: str, fetch: str, conn: Connection ) -> Optional[Union[Optional[tuple], list[tuple]]]: - cursor = conn.handle.cursor() + handle = conn.handle try: - cursor.execute(sql) + cursor = handle.execute(sql) if fetch == "one": return cursor.fetchone() elif fetch == "all": @@ -645,7 +641,8 @@ def run_sql_for_tests( print(e) raise finally: - cursor.close() + if cursor: + cursor.close() conn.transaction_open = False @available diff --git a/dbt/adapters/databricks/logging.py b/dbt/adapters/databricks/logging.py index d0f1d42ba..81e7449e1 100644 --- a/dbt/adapters/databricks/logging.py +++ b/dbt/adapters/databricks/logging.py @@ -1,7 +1,7 @@ -import os from logging import Handler, LogRecord, getLogger from typing import Union +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.events.logging import AdapterLogger logger = AdapterLogger("Databricks") @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None: dbt_adapter_logger = AdapterLogger("databricks-sql-connector") pysql_logger = getLogger("databricks.sql") -pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper() +pysql_logger_level = GlobalState.get_connector_log_level() pysql_logger.setLevel(pysql_logger_level) pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level) diff --git a/dbt/adapters/databricks/python_models/python_config.py b/dbt/adapters/databricks/python_models/python_config.py index 29aa44efa..5a39a5c98 100644 --- a/dbt/adapters/databricks/python_models/python_config.py +++ b/dbt/adapters/databricks/python_models/python_config.py @@ -36,6 +36,8 @@ class PythonModelConfig(BaseModel): cluster_id: Optional[str] = None http_path: Optional[str] = None create_notebook: bool = False + environment_key: Optional[str] = None + environment_dependencies: list[str] = Field(default_factory=list) class ParsedPythonModel(BaseModel): diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index afcb383c2..f98db234a 100644 --- 
a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -209,6 +209,8 @@ def __init__( self.job_grants = parsed_model.config.python_job_config.grants self.acls = parsed_model.config.access_control_list self.additional_job_settings = parsed_model.config.python_job_config.dict() + self.environment_key = parsed_model.config.environment_key + self.environment_deps = parsed_model.config.environment_dependencies def compile(self, path: str) -> PythonJobDetails: job_spec: dict[str, Any] = { @@ -217,9 +219,20 @@ def compile(self, path: str) -> PythonJobDetails: "notebook_path": path, }, } - job_spec.update(self.cluster_spec) # updates 'new_cluster' config additional_job_config = self.additional_job_settings + + if self.environment_key: + job_spec["environment_key"] = self.environment_key + if self.environment_deps and not self.additional_job_settings.get("environments"): + additional_job_config["environments"] = [ + { + "environment_key": self.environment_key, + "spec": {"client": "2", "dependencies": self.environment_deps}, + } + ] + job_spec.update(self.cluster_spec) # updates 'new_cluster' config + access_control_list = self.permission_builder.build_job_permissions( self.job_grants, self.acls ) diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 3dfd4096f..dccdd16c7 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -6,6 +6,7 @@ from jinja2 import Undefined from dbt.adapters.base import BaseAdapter +from dbt.adapters.databricks.logging import logger from dbt.adapters.spark.impl import TABLE_OR_VIEW_NOT_FOUND_MESSAGES if TYPE_CHECKING: @@ -32,7 +33,7 @@ def _redact_credentials_in_copy_into(sql: str) -> str: f"{key.strip()} = '[REDACTED]'" for key, _ in (pair.strip().split("=", 1) for pair in m.group(1).split(",")) ) - return f"{sql[: m.start()]} ({redacted}){sql[m.end():]}" + return f"{sql[: m.start()]} ({redacted}){sql[m.end() :]}" else: return sql @@ -77,3 +78,13 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T: def quote(name: str) -> str: return f"`{name}`" + + +ExceptionToStrOp = Callable[[Exception], str] + + +def handle_exceptions_as_warning(op: Callable[[], None], log_gen: ExceptionToStrOp) -> None: + try: + op() + except Exception as e: + logger.warning(log_gen(e)) diff --git a/dbt/include/databricks/macros/adapters/columns.sql b/dbt/include/databricks/macros/adapters/columns.sql index 7fe40e6f9..d9b041ccd 100644 --- a/dbt/include/databricks/macros/adapters/columns.sql +++ b/dbt/include/databricks/macros/adapters/columns.sql @@ -32,13 +32,13 @@ {{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }} {% endif %} {%- call statement('alter_relation_remove_columns') -%} - ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) + ALTER TABLE {{ relation.render() }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) {%- endcall -%} {% endif %} {% if add_columns %} {%- call statement('alter_relation_add_columns') -%} - ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }}) + ALTER TABLE {{ relation.render() }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }}) {%- endcall -%} {% endif %} {% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/adapters/persist_docs.sql b/dbt/include/databricks/macros/adapters/persist_docs.sql index 
873039e8d..a8ad48bab 100644 --- a/dbt/include/databricks/macros/adapters/persist_docs.sql +++ b/dbt/include/databricks/macros/adapters/persist_docs.sql @@ -4,7 +4,7 @@ {% set comment = column['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; + alter table {{ relation.render()|lower }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; {% endset %} {% do run_query(comment_query) %} {% endfor %} @@ -13,7 +13,7 @@ {% macro alter_table_comment(relation, model) %} {% set comment_query %} - comment on table {{ relation|lower }} is '{{ model.description | replace("'", "\\'") }}' + comment on table {{ relation.render()|lower }} is '{{ model.description | replace("'", "\\'") }}' {% endset %} {% do run_query(comment_query) %} {% endmacro %} diff --git a/dbt/include/databricks/macros/materializations/seeds/helpers.sql b/dbt/include/databricks/macros/materializations/seeds/helpers.sql index df690f18b..82acaba3d 100644 --- a/dbt/include/databricks/macros/materializations/seeds/helpers.sql +++ b/dbt/include/databricks/macros/materializations/seeds/helpers.sql @@ -6,7 +6,7 @@ {% set batch_size = get_batch_size() %} {% set column_override = model['config'].get('column_types', {}) %} - {% set must_cast = model['config'].get("file_format", "delta") == "parquet" %} + {% set must_cast = model['config'].get('file_format', 'delta') == 'parquet' %} {% set statements = [] %} diff --git a/dbt/include/databricks/macros/materializations/snapshot.sql b/dbt/include/databricks/macros/materializations/snapshot.sql index 3d1236a15..3a513a24d 100644 --- a/dbt/include/databricks/macros/materializations/snapshot.sql +++ b/dbt/include/databricks/macros/materializations/snapshot.sql @@ -1,27 +1,4 @@ -{% macro databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - - {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, - schema=target_relation.schema, - database=target_relation.database, - type='view') -%} - - {% set select = snapshot_staging_table(strategy, sql, target_relation) %} - - {# needs to be a non-temp view so that its columns can be ascertained via `describe` #} - {% call statement('build_snapshot_staging_relation') %} - create or replace view {{ tmp_relation }} - as - {{ select }} - {% endcall %} - - {% do return(tmp_relation) %} -{% endmacro %} - - {% materialization snapshot, adapter='databricks' %} - {%- set config = model['config'] -%} - {%- set target_table = model.get('alias', model.get('name')) -%} {%- set strategy_name = config.get('strategy') -%} @@ -62,47 +39,43 @@ {{ run_hooks(pre_hooks, inside_transaction=True) }} {% set strategy_macro = strategy_dispatch(strategy_name) %} - {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", model['config'], target_relation_exists) %} {% if not target_relation_exists %} {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} + {% set build_or_select_sql = build_sql %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} - {% call statement('main') %} - {{ final_sql }} - {% endcall %} - - {% do persist_docs(target_relation, model, for_relation=False) %} - {% else %} - {{ 
adapter.valid_snapshot_target(target_relation) }} + {% set columns = config.get("snapshot_table_column_names") or get_snapshot_table_column_names() %} - {% if target_relation.database is none %} - {% set staging_table = spark_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% else %} - {% set staging_table = databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% endif %} + {{ adapter.assert_valid_snapshot_target_given_strategy(target_relation, columns, strategy) }} + + {% set build_or_select_sql = snapshot_staging_table(strategy, sql, target_relation) %} + {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} -- this may no-op if the database does not require column expansion {% do adapter.expand_target_column_types(from_relation=staging_table, to_relation=target_relation) %} + {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} + {% if unique_key | is_list %} + {% for key in strategy.unique_key %} + {{ remove_columns.append('dbt_unique_key_' + loop.index|string) }} + {{ remove_columns.append('DBT_UNIQUE_KEY_' + loop.index|string) }} + {% endfor %} + {% endif %} + {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% do create_columns(target_relation, missing_columns) %} {% set source_columns = adapter.get_columns_in_relation(staging_table) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% set quoted_source_columns = [] %} @@ -117,23 +90,34 @@ ) %} - {% call statement_with_staging_table('main', staging_table) %} - {{ final_sql }} - {% endcall %} + {% endif %} - {% do persist_docs(target_relation, model, for_relation=True) %} - {% endif %} + {{ check_time_data_types(build_or_select_sql) }} - {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode) %} - {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% call statement('main') %} + {{ final_sql }} + {% endcall %} - {% do persist_constraints(target_relation, model) %} + {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode=False) %} + {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} + + {% do persist_docs(target_relation, model) %} + + {% if not target_relation_exists %} + {% do create_indexes(target_relation) %} + {% endif %} {{ run_hooks(post_hooks, inside_transaction=True) }} {{ adapter.commit() }} + {% if staging_table is defined %} + {% do post_snapshot(staging_table) %} + {% endif %} + + {% do persist_constraints(target_relation, model) %} + {{ run_hooks(post_hooks, inside_transaction=False) }} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql index 68f3a44fe..6d999823a 100644 --- a/dbt/include/databricks/macros/relations/constraints.sql +++ b/dbt/include/databricks/macros/relations/constraints.sql @@ -106,19 +106,19 @@ {% macro get_constraint_sql(relation, constraint, model, 
column={}) %} {% set statements = [] %} - {% set type = constraint.get("type", "") %} + {% set type = constraint.get('type', '') %} {% if type == 'check' %} - {% set expression = constraint.get("expression", "") %} + {% set expression = constraint.get('expression', '') %} {% if not expression %} {{ exceptions.raise_compiler_error('Invalid check constraint expression') }} {% endif %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }} - {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get("name", "") ~ ";" ~ expression ~ ";") -%} + {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get('name', '') ~ ";" ~ expression ~ ";") -%} {% else %} {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} @@ -126,7 +126,7 @@ {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " check (" ~ expression ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'not_null' %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -134,7 +134,7 @@ {% set column = model.get('columns', {}).get(column_name) %} {% if column %} {% set quoted_name = api.Column.get_name(column) %} - {% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} + {% set stmt = "alter table " ~ relation.render() ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} {% do statements.append(stmt) %} {% else %} {{ exceptions.warn('not_null constraint on invalid column: ' ~ column_name) }} @@ -144,7 +144,7 @@ {% if constraint.get('warn_unenforced') %} {{ exceptions.warn("unenforced constraint type: " ~ type)}} {% endif %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -161,7 +161,7 @@ {% set joined_names = quoted_names|join(", ") %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} @@ -170,7 +170,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'foreign_key' %} @@ -178,7 +178,7 @@ {{ exceptions.warn("unenforced constraint type: " ~ constraint.type)}} {% endif %} - {% set name = constraint.get("name") %} + {% set name = constraint.get('name') %} {% if constraint.get('expression') %} @@ -191,9 +191,9 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} {% else %} - {% set column_names = constraint.get("columns", []) %} + {% set column_names = constraint.get('columns', []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -210,7 +210,7 @@ {% set joined_names = quoted_names|join(", ") %} - {% set parent = constraint.get("to") %} + {% set parent = constraint.get('to') %} {% if not parent %} {{ exceptions.raise_compiler_error('No parent table defined for foreign key: ' ~ expression) }} {% endif %} @@ -227,8 +227,8 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} - {% set parent_columns = constraint.get("to_columns") %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} + {% set parent_columns = constraint.get('to_columns') %} {% if parent_columns %} {% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%} {% endif %} @@ -236,13 +236,13 @@ {% set stmt = stmt ~ ";" %} {% do statements.append(stmt) %} {% elif type == 'custom' %} - {% set expression = constraint.get("expression", "") %} + {% set expression = constraint.get('expression', '') %} {% if not expression %} {{ exceptions.raise_compiler_error('Missing custom constraint expression') }} {% endif %} - {% set name = constraint.get("name") %} - {% set expression = constraint.get("expression") %} + {% set name = constraint.get('name') %} + {% set expression = constraint.get('expression') %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} @@ -251,7 +251,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " " ~ expression ~ ";" %} + {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " " ~ expression ~ ";" %} {% do statements.append(stmt) %} {% elif constraint.get('warn_unsupported') %} {{ exceptions.warn("unsupported constraint type: " ~ constraint.type)}} @@ -264,15 +264,15 @@ {# convert constraints defined using the original databricks format #} {% set dbt_constraints = [] %} {% for constraint in constraints %} - {% if constraint.get and constraint.get("type") %} + {% if constraint.get and constraint.get('type') %} {# already in model contract format #} {% do dbt_constraints.append(constraint) %} {% else %} {% if column %} {% if constraint == "not_null" %} - {% do dbt_constraints.append({"type": "not_null", "columns": [column.get("name")]}) %} + {% do dbt_constraints.append({"type": "not_null", "columns": [column.get('name')]}) %} {% else %} - {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get("name", "") ~ '. Only `not_null` is supported.') }} + {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get('name', "") ~ '. Only `not_null` is supported.') }} {% endif %} {% else %} {% set name = constraint['name'] %} diff --git a/dbt/include/databricks/macros/relations/liquid_clustering.sql b/dbt/include/databricks/macros/relations/liquid_clustering.sql index 3cf810488..43a3b113e 100644 --- a/dbt/include/databricks/macros/relations/liquid_clustering.sql +++ b/dbt/include/databricks/macros/relations/liquid_clustering.sql @@ -1,21 +1,29 @@ {% macro liquid_clustered_cols() -%} {%- set cols = config.get('liquid_clustered_by', validator=validation.any[list, basestring]) -%} + {%- set auto_cluster = config.get('auto_liquid_cluster', validator=validation.any[boolean]) -%} {%- if cols is not none %} {%- if cols is string -%} {%- set cols = [cols] -%} {%- endif -%} CLUSTER BY ({{ cols | join(', ') }}) + {%- elif auto_cluster -%} + CLUSTER BY AUTO {%- endif %} {%- endmacro -%} {% macro apply_liquid_clustered_cols(target_relation) -%} {%- set cols = config.get('liquid_clustered_by', validator=validation.any[list, basestring]) -%} + {%- set auto_cluster = config.get('auto_liquid_cluster', validator=validation.any[boolean]) -%} {%- if cols is not none %} {%- if cols is string -%} {%- set cols = [cols] -%} {%- endif -%} {%- call statement('set_cluster_by_columns') -%} - ALTER {{ target_relation.type }} {{ target_relation }} CLUSTER BY ({{ cols | join(', ') }}) + ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY ({{ cols | join(', ') }}) + {%- endcall -%} + {%- elif auto_cluster -%} + {%- call statement('set_cluster_by_auto') -%} + ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY AUTO {%- endcall -%} {%- endif %} {%- endmacro -%} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/materialized_view/alter.sql b/dbt/include/databricks/macros/relations/materialized_view/alter.sql index 41d9bed06..d406508d2 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/alter.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/alter.sql @@ -46,6 +46,6 @@ {% macro get_alter_mv_internal(relation, configuration_changes) %} {%- set refresh = 
configuration_changes.changes["refresh"] -%} -- Currently only schedule can be altered - ALTER MATERIALIZED VIEW {{ relation }} + ALTER MATERIALIZED VIEW {{ relation.render() }} {{ get_alter_sql_refresh_schedule(refresh.cron, refresh.time_zone_value, refresh.is_altered) -}} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/drop.sql b/dbt/include/databricks/macros/relations/materialized_view/drop.sql index f3774119d..4def47441 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/drop.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_materialized_view(relation) -%} - drop materialized view if exists {{ relation }} + drop materialized view if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql index 10a8346be..9967eb21f 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql @@ -1,3 +1,3 @@ {% macro databricks__refresh_materialized_view(relation) -%} - refresh materialized view {{ relation }} + refresh materialized view {{ relation.render() }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/optimize.sql b/dbt/include/databricks/macros/relations/optimize.sql index 79f001646..a6108709d 100644 --- a/dbt/include/databricks/macros/relations/optimize.sql +++ b/dbt/include/databricks/macros/relations/optimize.sql @@ -6,7 +6,7 @@ {%- if var('DATABRICKS_SKIP_OPTIMIZE', 'false')|lower != 'true' and var('databricks_skip_optimize', 'false')|lower != 'true' and config.get('file_format', 'delta') == 'delta' -%} - {%- if (config.get('zorder', False) or config.get('liquid_clustered_by', False)) -%} + {%- if (config.get('zorder', False) or config.get('liquid_clustered_by', False)) or config.get('auto_liquid_cluster', False) -%} {%- call statement('run_optimize_stmt') -%} {{ get_optimize_sql(relation) }} {%- endcall -%} @@ -17,8 +17,8 @@ {%- macro get_optimize_sql(relation) %} optimize {{ relation }} {%- if config.get('zorder', False) and config.get('file_format', 'delta') == 'delta' %} - {%- if config.get('liquid_clustered_by', False) %} - {{ exceptions.warn("Both zorder and liquid_clustered_by are set but they are incompatible. zorder will be ignored.") }} + {%- if config.get('liquid_clustered_by', False) or config.get('auto_liquid_cluster', False) %} + {{ exceptions.warn("Both zorder and liquid_clustering are set but they are incompatible. zorder will be ignored.") }} {%- else %} {%- set zorder = config.get('zorder', none) %} {# TODO: predicates here? WHERE ... 
#} diff --git a/dbt/include/databricks/macros/relations/streaming_table/drop.sql b/dbt/include/databricks/macros/relations/streaming_table/drop.sql index c8e0cd839..1cfc246a8 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/drop.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/drop.sql @@ -3,5 +3,5 @@ {%- endmacro %} {% macro default__drop_streaming_table(relation) -%} - drop table if exists {{ relation }} + drop table if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql index 66b86f1f4..94c96d5cc 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql @@ -3,7 +3,7 @@ {%- endmacro %} {% macro databricks__refresh_streaming_table(relation, sql) -%} - create or refresh streaming table {{ relation }} + create or refresh streaming table {{ relation.render() }} as {{ sql }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/table/create.sql b/dbt/include/databricks/macros/relations/table/create.sql index 9e74d57d6..b2aba2fec 100644 --- a/dbt/include/databricks/macros/relations/table/create.sql +++ b/dbt/include/databricks/macros/relations/table/create.sql @@ -5,9 +5,9 @@ {%- else -%} {%- set file_format = config.get('file_format', default='delta') -%} {% if file_format == 'delta' %} - create or replace table {{ relation }} + create or replace table {{ relation.render() }} {% else %} - create table {{ relation }} + create table {{ relation.render() }} {% endif %} {%- set contract_config = config.get('contract') -%} {% if contract_config and contract_config.enforced %} diff --git a/dbt/include/databricks/macros/relations/table/drop.sql b/dbt/include/databricks/macros/relations/table/drop.sql index 3a7d0ced0..7bce7cf46 100644 --- a/dbt/include/databricks/macros/relations/table/drop.sql +++ b/dbt/include/databricks/macros/relations/table/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_table(relation) -%} - drop table if exists {{ relation }} + drop table if exists {{ relation.render() }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/tags.sql b/dbt/include/databricks/macros/relations/tags.sql index 3467631df..fb39c3785 100644 --- a/dbt/include/databricks/macros/relations/tags.sql +++ b/dbt/include/databricks/macros/relations/tags.sql @@ -33,7 +33,7 @@ {%- endmacro -%} {% macro alter_set_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation }} SET TAGS ( + ALTER {{ relation.type }} {{ relation.render() }} SET TAGS ( {% for tag in tags -%} '{{ tag }}' = '{{ tags[tag] }}' {%- if not loop.last %}, {% endif -%} {%- endfor %} @@ -41,7 +41,7 @@ {%- endmacro -%} {% macro alter_unset_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation }} UNSET TAGS ( + ALTER {{ relation.type }} {{ relation.render() }} UNSET TAGS ( {% for tag in tags -%} '{{ tag }}' {%- if not loop.last %}, {%- endif %} {%- endfor %} diff --git a/dbt/include/databricks/macros/relations/tblproperties.sql b/dbt/include/databricks/macros/relations/tblproperties.sql index 34b6488f7..b11fd7b5c 100644 --- a/dbt/include/databricks/macros/relations/tblproperties.sql +++ b/dbt/include/databricks/macros/relations/tblproperties.sql @@ -17,7 +17,7 @@ {% set tblproperty_statment = databricks__tblproperties_clause(tblproperties) %} {% if tblproperty_statment %} {%- call statement('apply_tblproperties') -%} - 
ALTER {{ relation.type }} {{ relation }} SET {{ tblproperty_statment}} + ALTER {{ relation.type }} {{ relation.render() }} SET {{ tblproperty_statment}} {%- endcall -%} {% endif %} {%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/view/create.sql b/dbt/include/databricks/macros/relations/view/create.sql index 096e12de4..5399b4ef5 100644 --- a/dbt/include/databricks/macros/relations/view/create.sql +++ b/dbt/include/databricks/macros/relations/view/create.sql @@ -1,5 +1,5 @@ {% macro databricks__create_view_as(relation, sql) -%} - create or replace view {{ relation }} + create or replace view {{ relation.render() }} {% if config.persist_column_docs() -%} {% set model_columns = model.columns %} {% set query_columns = get_columns_in_query(sql) %} diff --git a/dbt/include/databricks/macros/relations/view/drop.sql b/dbt/include/databricks/macros/relations/view/drop.sql index aa199d760..9098c925f 100644 --- a/dbt/include/databricks/macros/relations/view/drop.sql +++ b/dbt/include/databricks/macros/relations/view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_view(relation) -%} - drop view if exists {{ relation }} + drop view if exists {{ relation.render() }} {%- endmacro %} diff --git a/pyproject.toml b/pyproject.toml index d2f728d8a..c176f1ed1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,10 +65,10 @@ check-sdist = [ [tool.hatch.envs.default] dependencies = [ "dbt_common @ git+https://github.com/dbt-labs/dbt-common.git", - "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git@main", + "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-adapters", "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git@main#subdirectory=core", "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter", - "dbt-spark @ git+https://github.com/dbt-labs/dbt-spark.git@main", + "dbt-spark @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-spark", "pytest", "pytest-xdist", "pytest-dotenv", diff --git a/tests/functional/adapter/iceberg/test_iceberg_support.py b/tests/functional/adapter/iceberg/test_iceberg_support.py index b25a54af7..a52f4b97c 100644 --- a/tests/functional/adapter/iceberg/test_iceberg_support.py +++ b/tests/functional/adapter/iceberg/test_iceberg_support.py @@ -5,8 +5,7 @@ from tests.functional.adapter.iceberg import fixtures -# @pytest.mark.skip_profile("databricks_cluster") -@pytest.mark.skip("Skip for now as it is broken in prod") +@pytest.mark.skip_profile("databricks_cluster") class TestIcebergTables: @pytest.fixture(scope="class") def models(self): @@ -21,8 +20,7 @@ def test_iceberg_refs(self, project): assert len(run_results) == 3 -# @pytest.mark.skip_profile("databricks_cluster") -@pytest.mark.skip("Skip for now as it is broken in prod") +@pytest.mark.skip_profile("databricks_cluster") class TestIcebergSwap: @pytest.fixture(scope="class") def models(self): diff --git a/tests/functional/adapter/liquid_clustering/fixtures.py b/tests/functional/adapter/liquid_clustering/fixtures.py index 951210ea6..6c7302227 100644 --- a/tests/functional/adapter/liquid_clustering/fixtures.py +++ b/tests/functional/adapter/liquid_clustering/fixtures.py @@ -2,3 +2,8 @@ {{ config(materialized='incremental', liquid_clustered_by='id') }} select 1 as id, 'Joe' as name """ + +auto_liquid_cluster_sql = """ +{{ config(materialized='incremental', auto_liquid_cluster=true) }} +select 1 as id, 'Joe' as name +""" diff --git 
a/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py b/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py index a9cc0ee09..45c7dfe2c 100644 --- a/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py +++ b/tests/functional/adapter/liquid_clustering/test_liquid_clustering.py @@ -15,3 +15,16 @@ def models(self): def test_liquid_clustering(self, project): _, logs = util.run_dbt_and_capture(["--debug", "run"]) assert "optimize" in logs + + +class TestAutoLiquidClustering: + @pytest.fixture(scope="class") + def models(self): + return { + "liquid_clustering.sql": fixtures.liquid_cluster_sql, + } + + @pytest.mark.skip_profile("databricks_uc_cluster", "databricks_cluster") + def test_liquid_clustering(self, project): + _, logs = util.run_dbt_and_capture(["--debug", "run"]) + assert "optimize" in logs diff --git a/tests/functional/adapter/python_model/fixtures.py b/tests/functional/adapter/python_model/fixtures.py index 127bcf74e..d1a4dd4a9 100644 --- a/tests/functional/adapter/python_model/fixtures.py +++ b/tests/functional/adapter/python_model/fixtures.py @@ -42,6 +42,32 @@ def model(dbt, spark): identifier: source """ +serverless_schema_with_environment = """version: 2 + +models: + - name: my_versioned_sql_model + versions: + - v: 1 + - name: my_python_model + config: + submission_method: serverless_cluster + create_notebook: true + environment_key: "test_key" + environment_dependencies: ["requests"] + +sources: + - name: test_source + loader: custom + schema: "{{ var(env_var('DBT_TEST_SCHEMA_NAME_VARIABLE')) }}" + quoting: + identifier: True + tags: + - my_test_source_tag + tables: + - name: test_table + identifier: source +""" + workflow_schema = """version: 2 models: diff --git a/tests/functional/adapter/python_model/test_python_model.py b/tests/functional/adapter/python_model/test_python_model.py index 726791dfa..ce490c0f3 100644 --- a/tests/functional/adapter/python_model/test_python_model.py +++ b/tests/functional/adapter/python_model/test_python_model.py @@ -113,6 +113,21 @@ def models(self): } +@pytest.mark.python +# @pytest.mark.skip_profile("databricks_cluster", "databricks_uc_cluster") +@pytest.mark.skip("Not available in Databricks yet") +class TestServerlessClusterWithEnvironment(BasePythonModelTests): + @pytest.fixture(scope="class") + def models(self): + return { + "schema.yml": override_fixtures.serverless_schema_with_environment, + "my_sql_model.sql": fixtures.basic_sql, + "my_versioned_sql_model_v1.sql": fixtures.basic_sql, + "my_python_model.py": fixtures.basic_python, + "second_sql_model.sql": fixtures.second_sql, + } + + @pytest.mark.python @pytest.mark.external @pytest.mark.skip_profile("databricks_cluster", "databricks_uc_sql_endpoint") diff --git a/tests/functional/adapter/simple_snapshot/test_new_record_mode.py b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py new file mode 100644 index 000000000..6b436a311 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py @@ -0,0 +1,74 @@ +import pytest + +from dbt.tests.adapter.simple_snapshot.new_record_mode import ( + _delete_sql, + _invalidate_sql, + _ref_snapshot_sql, + _seed_new_record_mode, + _snapshot_actual_sql, + _snapshots_yml, + _update_sql, +) +from dbt.tests.util import check_relations_equal, run_dbt + + +class TestDatabricksSnapshotNewRecordMode: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": _snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + 
"snapshots.yml": _snapshots_yml, + "ref_snapshot.sql": _ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def seed_new_record_mode(self): + return _seed_new_record_mode + + @pytest.fixture(scope="class") + def invalidate_sql_1(self): + return _invalidate_sql.split(";", 1)[0].replace("BEGIN", "") + + @pytest.fixture(scope="class") + def invalidate_sql_2(self): + return _invalidate_sql.split(";", 1)[1].replace("END", "").replace(";", "") + + @pytest.fixture(scope="class") + def update_sql(self): + return _update_sql.replace("text", "string") + + @pytest.fixture(scope="class") + def delete_sql(self): + return _delete_sql + + def test_snapshot_new_record_mode( + self, project, seed_new_record_mode, invalidate_sql_1, invalidate_sql_2, update_sql + ): + for sql in ( + seed_new_record_mode.replace("text", "string") + .replace("TEXT", "STRING") + .replace("BEGIN", "") + .replace("END;", "") + .replace(" WITHOUT TIME ZONE", "") + .split(";") + ): + project.run_sql(sql) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + project.run_sql(_delete_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 diff --git a/tests/functional/adapter/simple_snapshot/test_snapshot.py b/tests/functional/adapter/simple_snapshot/test_snapshot.py index 56186e4ef..a77e9befa 100644 --- a/tests/functional/adapter/simple_snapshot/test_snapshot.py +++ b/tests/functional/adapter/simple_snapshot/test_snapshot.py @@ -32,6 +32,17 @@ class TestSnapshotCheck(BaseSnapshotCheck): pass +@pytest.mark.skip_profile("databricks_cluster") +class TestSnapshotIceberg(BaseSnapshotCheck): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "+table_format": "iceberg", + } + } + + class TestSnapshotPersistDocs: @pytest.fixture(scope="class") def models(self): diff --git a/tests/functional/adapter/simple_snapshot/test_various_configs.py b/tests/functional/adapter/simple_snapshot/test_various_configs.py new file mode 100644 index 000000000..18b82de00 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_various_configs.py @@ -0,0 +1,345 @@ +import datetime + +import pytest +from agate import Table + +from dbt.tests.adapter.simple_snapshot.fixtures import ( + create_multi_key_seed_sql, + create_multi_key_snapshot_expected_sql, + create_seed_sql, + create_snapshot_expected_sql, + model_seed_sql, + populate_multi_key_snapshot_expected_sql, + populate_snapshot_expected_sql, + populate_snapshot_expected_valid_to_current_sql, + ref_snapshot_sql, + seed_insert_sql, + seed_multi_key_insert_sql, + snapshot_actual_sql, + snapshots_multi_key_yml, + snapshots_no_column_names_yml, + snapshots_valid_to_current_yml, + snapshots_yml, + update_multi_key_sql, + update_sql, + update_with_current_sql, +) +from dbt.tests.util import ( + check_relations_equal, + get_manifest, + run_dbt, + run_dbt_and_capture, + run_sql_with_adapter, + update_config_file, +) + + +def text_replace(input: str) -> str: + return input.replace("TEXT", "STRING").replace("text", "string") + + +create_snapshot_expected_sql = text_replace(create_snapshot_expected_sql) +populate_snapshot_expected_sql = text_replace(populate_snapshot_expected_sql) +populate_snapshot_expected_valid_to_current_sql = text_replace( + 
populate_snapshot_expected_valid_to_current_sql +) +update_with_current_sql = text_replace(update_with_current_sql) +create_multi_key_snapshot_expected_sql = text_replace(create_multi_key_snapshot_expected_sql) +populate_multi_key_snapshot_expected_sql = text_replace(populate_multi_key_snapshot_expected_sql) +update_sql = text_replace(update_sql) +update_multi_key_sql = text_replace(update_multi_key_sql) + +invalidate_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id >= 10 and id <= 20 +""" + +invalidate_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id >= 10 and id <= 20; +""" + +invalidate_multi_key_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id1 = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id1 >= 10 and id1 <= 20; +""" + +invalidate_multi_key_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id1 >= 10 and id1 <= 20; +""" + + +class BaseSnapshotColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_snapshot_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotColumnNamesFromDbtProject: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_column_names_from_project(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotInvalidColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + 
"snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_invalid_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + manifest = get_manifest(project.project_root) + snapshot_node = manifest.nodes["snapshot.test.snapshot_actual"] + snapshot_node.config.snapshot_meta_column_names == { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + # Change snapshot_meta_columns and look for an error + different_columns = { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_updated_at": "test_updated_at", + } + } + } + } + update_config_file(different_columns, "dbt_project.yml") + + results, log_output = run_dbt_and_capture(["snapshot"], expect_pass=False) + assert len(results) == 1 + assert "Compilation Error in snapshot snapshot_actual" in log_output + assert "Snapshot target is missing configured columns" in log_output + + +class BaseSnapshotDbtValidToCurrent: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_valid_to_current_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_valid_to_current(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_valid_to_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + original_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + assert original_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", original_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_with_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + updated_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + print(updated_snapshot) + assert updated_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + # Original row that was updated now has a non-current (2099/12/31) date + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", updated_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2016, 8, 20, 16, 44, 49, tzinfo=datetime.timezone.utc + ) + updated_row = list( + filter(lambda x: x[1] == 
"af1f803f2179869aeacb1bfe2b23c1df", updated_snapshot) + ) + + # Updated row has a current date + assert updated_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +# This uses snapshot_meta_column_names, yaml-only snapshot def, +# and multiple keys +class BaseSnapshotMultiUniqueKey: + @pytest.fixture(scope="class") + def models(self): + return { + "seed.sql": model_seed_sql, + "snapshots.yml": snapshots_multi_key_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_multi_column_unique_key(self, project): + project.run_sql(create_multi_key_seed_sql) + project.run_sql(create_multi_key_snapshot_expected_sql) + project.run_sql(seed_multi_key_insert_sql) + project.run_sql(populate_multi_key_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_multi_key_sql_1) + project.run_sql(invalidate_multi_key_sql_2) + project.run_sql(update_multi_key_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestDatabricksSnapshotColumnNames(BaseSnapshotColumnNames): + pass + + +class TestDatabricksSnapshotColumnNamesFromDbtProject(BaseSnapshotColumnNamesFromDbtProject): + pass + + +class TestDatabricksSnapshotInvalidColumnNames(BaseSnapshotInvalidColumnNames): + pass + + +class TestDatabricksSnapshotDbtValidToCurrent(BaseSnapshotDbtValidToCurrent): + pass + + +class TestDatabricksSnapshotMultiUniqueKey(BaseSnapshotMultiUniqueKey): + pass diff --git a/tests/unit/events/test_connection_events.py b/tests/unit/events/test_connection_events.py index 81a30c1dc..6c005fb64 100644 --- a/tests/unit/events/test_connection_events.py +++ b/tests/unit/events/test_connection_events.py @@ -2,7 +2,6 @@ from dbt.adapters.databricks.events.connection_events import ( ConnectionAcquire, - ConnectionCloseError, ConnectionEvent, ) @@ -30,16 +29,6 @@ def test_connection_event__with_id(self): assert str(event) == "Connection(session-id=1234) - This is a test" -class TestConnectionCloseError: - def test_connection_close_error(self): - e = Exception("This is an exception") - event = ConnectionCloseError(None, e) - assert ( - str(event) == "Connection(session-id=Unknown) " - "- Exception while trying to close connection: This is an exception" - ) - - class TestConnectionAcquire: def test_connection_acquire__missing_data(self): event = ConnectionAcquire(None, None, None, (0, 0)) diff --git a/tests/unit/events/test_cursor_events.py b/tests/unit/events/test_cursor_events.py deleted file mode 100644 index e492415a6..000000000 --- a/tests/unit/events/test_cursor_events.py +++ /dev/null @@ -1,41 +0,0 @@ -from unittest.mock import Mock - -from dbt.adapters.databricks.events.cursor_events import CursorCloseError, CursorEvent - - -class CursorTestEvent(CursorEvent): - def __init__(self, cursor): - super().__init__(cursor, "This is a test") - - -class TestCursorEvents: - def test_cursor_event__no_cursor(self): - event = CursorTestEvent(None) - assert str(event) == "Cursor(session-id=Unknown, command-id=Unknown) - This is a test" - - def test_cursor_event__no_ids(self): - mock = Mock() - mock.connection = None - mock.active_result_set = None - event = CursorTestEvent(mock) - assert str(event) == "Cursor(session-id=Unknown, command-id=Unknown) - This is a test" - - def test_cursor_event__with_ids(self): - mock = Mock() - 
mock.connection.get_session_id_hex.return_value = "1234" - mock.active_result_set.command_id.operationId.guid = (1234).to_bytes(16, "big") - event = CursorTestEvent(mock) - assert ( - str(event) == "Cursor(session-id=1234, command-id=00000000-0000-0000-0000-0000000004d2)" - " - This is a test" - ) - - -class TestCursorCloseError: - def test_cursor_close_error(self): - e = Exception("This is an exception") - event = CursorCloseError(None, e) - assert ( - str(event) == "Cursor(session-id=Unknown, command-id=Unknown) " - "- Exception while trying to close cursor: This is an exception" - ) diff --git a/tests/unit/macros/relations/test_table_macros.py b/tests/unit/macros/relations/test_table_macros.py index 9c5635131..f05dc9d40 100644 --- a/tests/unit/macros/relations/test_table_macros.py +++ b/tests/unit/macros/relations/test_table_macros.py @@ -180,6 +180,16 @@ def test_macros_create_table_as_liquid_clusters(self, config, template_bundle): assert sql == expected + def test_macros_create_table_as_liquid_cluster_auto(self, config, template_bundle): + config["auto_liquid_cluster"] = True + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table {template_bundle.relation} using" + " delta CLUSTER BY AUTO as select 1" + ) + + assert sql == expected + def test_macros_create_table_as_comment(self, config, template_bundle): config["persist_docs"] = {"relation": True} template_bundle.context["model"].description = "Description Test" diff --git a/tests/unit/python/test_python_job_support.py b/tests/unit/python/test_python_job_support.py index 41f480413..fb996efa4 100644 --- a/tests/unit/python/test_python_job_support.py +++ b/tests/unit/python/test_python_job_support.py @@ -146,8 +146,16 @@ def run_name(self, parsed_model): parsed_model.config.additional_libs = [] return run_name + @pytest.fixture + def environment_key(self, parsed_model): + environment_key = "test_key" + parsed_model.config.environment_key = environment_key + parsed_model.config.environment_dependencies = ["requests"] + return environment_key + def test_compile__empty_configs(self, client, permission_builder, parsed_model, run_name): parsed_model.config.python_job_config.dict.return_value = {} + parsed_model.config.environment_key = None compiler = PythonJobConfigCompiler(client, permission_builder, parsed_model, {}) permission_builder.build_job_permissions.return_value = [] details = compiler.compile("path") @@ -162,7 +170,9 @@ def test_compile__empty_configs(self, client, permission_builder, parsed_model, } assert details.additional_job_config == {} - def test_compile__nonempty_configs(self, client, permission_builder, parsed_model, run_name): + def test_compile__nonempty_configs( + self, client, permission_builder, parsed_model, run_name, environment_key + ): parsed_model.config.packages = ["foo"] parsed_model.config.index_url = None parsed_model.config.python_job_config.dict.return_value = {"foo": "bar"} @@ -176,6 +186,7 @@ def test_compile__nonempty_configs(self, client, permission_builder, parsed_mode details = compiler.compile("path") assert details.run_name == run_name assert details.job_spec == { + "environment_key": environment_key, "task_key": "inner_notebook", "notebook_task": { "notebook_path": "path", @@ -185,4 +196,12 @@ def test_compile__nonempty_configs(self, client, permission_builder, parsed_mode "access_control_list": [{"user_name": "user", "permission_level": "IS_OWNER"}], "queue": {"enabled": True}, } - assert details.additional_job_config == {"foo": "bar"} + assert 
details.additional_job_config == {
+            "foo": "bar",
+            "environments": [
+                {
+                    "environment_key": environment_key,
+                    "spec": {"client": "2", "dependencies": ["requests"]},
+                }
+            ],
+        }
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 78ae12cbb..f8c3b4a17 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -11,8 +11,6 @@
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.credentials import (
     CATALOG_KEY_IN_SESSION_PROPERTIES,
-    DBT_DATABRICKS_HTTP_SESSION_HEADERS,
-    DBT_DATABRICKS_INVOCATION_ENV,
 )
 from dbt.adapters.databricks.impl import get_identifier_list_string
 from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
@@ -79,8 +77,7 @@ def test_two_catalog_settings(self):
         )
 
         expected_message = (
-            "Got duplicate keys: (`databricks.catalog` in session_properties)"
-            ' all map to "database"'
+            'Got duplicate keys: (`databricks.catalog` in session_properties) all map to "database"'
         )
 
         assert expected_message in str(excinfo.value)
@@ -114,7 +111,10 @@ def test_invalid_custom_user_agent(self):
         with pytest.raises(DbtValidationError) as excinfo:
             config = self._get_config()
             adapter = DatabricksAdapter(config, get_context("spawn"))
-            with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
+                return_value="(Some-thing)",
+            ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
 
@@ -125,11 +125,12 @@ def test_custom_user_agent(self):
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
         with patch(
-            "dbt.adapters.databricks.connections.dbsql.connect",
+            "dbt.adapters.databricks.handle.dbsql.connect",
             new=self._connect_func(expected_invocation_env="databricks-workflows"),
         ):
-            with patch.dict(
-                "os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"}
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
+                return_value="databricks-workflows",
             ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -187,12 +188,12 @@ def _test_environment_http_headers(
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
         with patch(
-            "dbt.adapters.databricks.connections.dbsql.connect",
+            "dbt.adapters.databricks.handle.dbsql.connect",
             new=self._connect_func(expected_http_headers=expected_http_headers),
         ):
-            with patch.dict(
-                "os.environ",
-                **{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str},
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers",
+                return_value=http_headers_str,
             ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -204,7 +205,7 @@ def test_oauth_settings(self):
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
         with patch(
-            "dbt.adapters.databricks.connections.dbsql.connect",
+            "dbt.adapters.databricks.handle.dbsql.connect",
             new=self._connect_func(expected_no_token=True),
         ):
             connection = adapter.acquire_connection("dummy")
@@ -217,7 +218,7 @@ def test_client_creds_settings(self):
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
         with patch(
-            "dbt.adapters.databricks.connections.dbsql.connect",
+            "dbt.adapters.databricks.handle.dbsql.connect",
             new=self._connect_func(expected_client_creds=True),
         ):
             connection = adapter.acquire_connection("dummy")
@@ -266,6 +267,7 @@ def connect(
                 assert http_headers is None
             else:
                 assert http_headers == expected_http_headers
+            return Mock()
 
         return connect
 
@@ -276,7 +278,7 @@ def _test_databricks_sql_connector_connection(self, connect):
         config = self._get_config()
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
-        with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
+        with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
 
@@ -300,7 +302,7 @@ def _test_databricks_sql_connector_catalog_connection(self, connect):
         config = self._get_config()
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
-        with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
+        with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
 
@@ -327,7 +329,7 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
         config = self._get_config(connection_parameters={"http_headers": http_headers})
         adapter = DatabricksAdapter(config, get_context("spawn"))
 
-        with patch("dbt.adapters.databricks.connections.dbsql.connect", new=connect):
+        with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
 
@@ -912,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self):
         assert get_identifier_list_string(table_names) == "|".join(table_names)
 
         # If environment variable is set, then limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value="true",
+        ):
             # Long list of table names is capped
             assert get_identifier_list_string(table_names) == "*"
 
@@ -941,7 +946,10 @@ def test_describe_table_extended_should_limit(self):
         table_names = set([f"customers_{i}" for i in range(200)])
 
         # If environment variable is set, then limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value="true",
+        ):
             # Long list of table names is capped
             assert get_identifier_list_string(table_names) == "*"
 
@@ -954,7 +962,10 @@ def test_describe_table_extended_may_limit(self):
         table_names = set([f"customers_{i}" for i in range(200)])
 
         # If environment variable is set, then we may limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value="true",
+        ):
             # But a short list of table names is not capped
             assert get_identifier_list_string(list(table_names)[:5]) == "|".join(
                 list(table_names)[:5]
diff --git a/tests/unit/test_compute_config.py b/tests/unit/test_compute_config.py
deleted file mode 100644
index 994d4ae9a..000000000
--- a/tests/unit/test_compute_config.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from unittest.mock import Mock
-
-import pytest
-from dbt_common.exceptions import DbtRuntimeError
-
-from dbt.adapters.databricks import connections
-from dbt.adapters.databricks.credentials import DatabricksCredentials
-
-
-class TestDatabricksConnectionHTTPPath:
-    """Test the various cases for determining a specified warehouse."""
-
-    @pytest.fixture(scope="class")
-    def err_msg(self):
-        return (
-            "Compute resource foo does not exist or does not specify http_path,"
-            " relation: a_relation"
-        )
-
-    @pytest.fixture(scope="class")
-    def path(self):
-        return "my_http_path"
-
-    @pytest.fixture
-    def creds(self, path):
-        return DatabricksCredentials(http_path=path)
-
-    @pytest.fixture
-    def node(self):
-        n = Mock()
-        n.config = {}
-        n.relation_name = "a_relation"
-        return n
-
-    def test_get_http_path__empty(self, path, creds):
-        assert connections._get_http_path(None, creds) == path
-
-    def test_get_http_path__no_compute(self, node, path, creds):
-        assert connections._get_http_path(node, creds) == path
-
-    def test_get_http_path__missing_compute(self, node, creds, err_msg):
-        node.config["databricks_compute"] = "foo"
-        with pytest.raises(DbtRuntimeError) as exc:
-            connections._get_http_path(node, creds)
-
-        assert err_msg in str(exc.value)
-
-    def test_get_http_path__empty_compute(self, node, creds, err_msg):
-        node.config["databricks_compute"] = "foo"
-        creds.compute = {"foo": {}}
-        with pytest.raises(DbtRuntimeError) as exc:
-            connections._get_http_path(node, creds)
-
-        assert err_msg in str(exc.value)
-
-    def test_get_http_path__matching_compute(self, node, creds):
-        node.config["databricks_compute"] = "foo"
-        creds.compute = {"foo": {"http_path": "alternate_path"}}
-        assert "alternate_path" == connections._get_http_path(node, creds)
diff --git a/tests/unit/test_handle.py b/tests/unit/test_handle.py
new file mode 100644
index 000000000..013260630
--- /dev/null
+++ b/tests/unit/test_handle.py
@@ -0,0 +1,211 @@
+import sys
+from decimal import Decimal
+from unittest.mock import Mock
+
+import pytest
+from databricks.sql.client import Cursor
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.contracts.connection import AdapterResponse
+from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
+
+
+class TestSqlUtils:
+    @pytest.mark.parametrize(
+        "bindings, expected", [(None, None), ([1], [1]), ([1, Decimal(0.73)], [1, 0.73])]
+    )
+    def test_translate_bindings(self, bindings, expected):
+        assert SqlUtils.translate_bindings(bindings) == expected
+
+    @pytest.mark.parametrize(
+        "sql, expected", [(" select 1; ", "select 1"), ("select 1", "select 1")]
+    )
+    def test_clean_sql(self, sql, expected):
+        assert SqlUtils.clean_sql(sql) == expected
+
+    @pytest.mark.parametrize("result, expected", [("14.x", (14, sys.maxsize)), ("12.1", (12, 1))])
+    def test_extract_dbr_version(self, result, expected):
+        assert SqlUtils.extract_dbr_version(result) == expected
+
+    def test_extract_dbr_version__invalid(self):
+        with pytest.raises(DbtRuntimeError):
+            SqlUtils.extract_dbr_version("foo")
+
+
+class TestCursorWrapper:
+    @pytest.fixture
+    def cursor(self):
+        return Mock()
+
+    def test_description(self, cursor):
+        cursor.description = [("foo", "bar")]
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.description == [("foo", "bar")]
+
+    def test_cancel__closed(self, cursor):
+        wrapper = CursorWrapper(cursor)
+        wrapper.open = False
+        wrapper.cancel()
+        cursor.cancel.assert_not_called()
+
+    def test_cancel__open_no_result_set(self, cursor):
+        wrapper = CursorWrapper(cursor)
+        cursor.active_result_set = None
+        wrapper.cancel()
+        assert wrapper.open is False
+
+    def test_cancel__open_with_result_set(self, cursor):
+        wrapper = CursorWrapper(cursor)
+        wrapper.cancel()
+        cursor.cancel.assert_called_once()
+
+    def test_cancel__error_cancelling(self, cursor):
+        cursor.cancel.side_effect = Exception("foo")
+        wrapper = CursorWrapper(cursor)
+        wrapper.cancel()
+        cursor.cancel.assert_called_once()
+
+    def test_closed__closed(self, cursor):
+        wrapper = CursorWrapper(cursor)
+        wrapper.open = False
+        wrapper.close()
+        cursor.close.assert_not_called()
+
+    def test_closed__open(self, cursor):
+        wrapper = CursorWrapper(cursor)
+        cursor.active_result_set = None
+        wrapper.close()
+        assert wrapper.open is False
+
+    def test_close__error_closing(self, cursor):
+        cursor.close.side_effect = Exception("foo")
+        wrapper = CursorWrapper(cursor)
+        wrapper.close()
+        cursor.close.assert_called_once()
+
+    def test_fetchone(self, cursor):
+        cursor.fetchone.return_value = [("foo", "bar")]
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.fetchone() == [("foo", "bar")]
+
+    def test_fetchall(self, cursor):
+        cursor.fetchall.return_value = [("foo", "bar")]
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.fetchall() == [("foo", "bar")]
+
+    def test_fetchmany(self, cursor):
+        cursor.fetchmany.return_value = [("foo", "bar")]
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.fetchmany(1) == [("foo", "bar")]
+
+    def test_get_response__no_query_id(self, cursor):
+        cursor.query_id = None
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.get_response() == AdapterResponse("OK", query_id="N/A")
+
+    def test_get_response__with_query_id(self, cursor):
+        cursor.query_id = "id"
+        wrapper = CursorWrapper(cursor)
+        assert wrapper.get_response() == AdapterResponse("OK", query_id="id")
+
+    def test_with__no_exception(self, cursor):
+        with CursorWrapper(cursor) as c:
+            c.fetchone()
+        cursor.fetchone.assert_called_once()
+        cursor.close.assert_called_once()
+
+    def test_with__exception(self, cursor):
+        cursor.fetchone.side_effect = Exception("foo")
+        with pytest.raises(Exception, match="foo"):
+            with CursorWrapper(cursor) as c:
+                c.fetchone()
+        cursor.fetchone.assert_called_once()
+        cursor.close.assert_called_once()
+
+
+class TestDatabricksHandle:
+    @pytest.fixture
+    def conn(self):
+        return Mock()
+
+    @pytest.fixture
+    def cursor(self):
+        return Mock()
+
+    def test_safe_execute__closed(self, conn):
+        handle = DatabricksHandle(conn, True)
+        handle.open = False
+        with pytest.raises(DbtRuntimeError, match="Attempting to execute on a closed connection"):
+            handle._safe_execute(Mock())
+
+    def test_safe_execute__with_cursor(self, conn, cursor):
+        new_cursor = Mock()
+
+        def f(_: Cursor) -> Cursor:
+            return new_cursor
+
+        handle = DatabricksHandle(conn, True)
+        handle._cursor = cursor
+        assert handle._safe_execute(f)._cursor == new_cursor
+        assert handle._cursor._cursor == new_cursor
+        cursor.close.assert_called_once()
+
+    def test_safe_execute__without_cursor(self, conn):
+        new_cursor = Mock()
+
+        def f(_: Cursor) -> Cursor:
+            return new_cursor
+
+        handle = DatabricksHandle(conn, True)
+        assert handle._safe_execute(f)._cursor == new_cursor
+        assert handle._cursor._cursor == new_cursor
+
+    def test_cancel__closed(self, conn):
+        handle = DatabricksHandle(conn, True)
+        handle.open = False
+        handle.cancel()
+        conn.close.assert_not_called()
+
+    def test_cancel__open_no_cursor(self, conn):
+        handle = DatabricksHandle(conn, True)
+        handle.cancel()
+        conn.close.assert_called_once()
+
+    def test_cancel__open_cursor(self, conn, cursor):
+        handle = DatabricksHandle(conn, True)
+        handle._cursor = cursor
+        handle.cancel()
+        cursor.cancel.assert_called_once()
+        conn.close.assert_called_once()
+
+    def test_cancel__open_raising_exception(self, conn):
+        conn.close.side_effect = Exception("foo")
+        handle = DatabricksHandle(conn, True)
+        handle.cancel()
+        conn.close.assert_called_once()
+
+    def test_close__closed(self, conn):
+        handle = DatabricksHandle(conn, True)
+        handle.open = False
+        handle.close()
+        conn.close.assert_not_called()
+
+    def test_close__open_no_cursor(self, conn):
+        handle = DatabricksHandle(conn, True)
+        handle.close()
+        conn.close.assert_called_once()
+
+    def test_close__open_cursor(self, conn, cursor):
+        handle = DatabricksHandle(conn, True)
+        handle._cursor = cursor
+        handle.close()
+        cursor.close.assert_called_once()
+        conn.close.assert_called_once()
+
+    def test_close__open_raising_exception(self, conn, cursor):
+        conn.close.side_effect = Exception("foo")
+        handle = DatabricksHandle(conn, True)
+        handle._cursor = cursor
+        handle.close()
+        cursor.close.assert_called_once()
+        conn.close.assert_called_once()
diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py
deleted file mode 100644
index de3545680..000000000
--- a/tests/unit/test_idle_config.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import pytest
-from dbt_common.exceptions import DbtRuntimeError
-
-from dbt.adapters.databricks import connections
-from dbt.adapters.databricks.credentials import DatabricksCredentials
-from dbt.contracts.graph import model_config, nodes
-
-
-class TestDatabricksConnectionMaxIdleTime:
-    """Test the various cases for determining a specified warehouse."""
-
-    errMsg = (
-        "Compute resource foo does not exist or does not specify http_path, " "relation: a_relation"
-    )
-
-    def test_get_max_idle_default(self):
-        creds = DatabricksCredentials()
-
-        # No node and nothing specified in creds
-        time = connections._get_max_idle_time(None, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        node = nodes.ModelNode(
-            relation_name="a_relation",
-            database="database",
-            schema="schema",
-            name="node_name",
-            resource_type="model",
-            package_name="package",
-            path="path",
-            original_file_path="orig_path",
-            unique_id="uniqueID",
-            fqn=[],
-            alias="alias",
-            checksum=None,
-        )
-
-        # node has no configuration so should get back default
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        # empty configuration should return default
-        node.config = model_config.ModelConfig()
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        # node with no extras in configuration should return default
-        node.config._extra = {}
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        # node that specifies a compute with no corresponding definition should return default
-        node.config._extra["databricks_compute"] = "foo"
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        creds.compute = {}
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-
-        # if alternate compute doesn't specify a max time should return default
-        creds.compute = {"foo": {}}
-        time = connections._get_max_idle_time(node, creds)
-        assert connections.DEFAULT_MAX_IDLE_TIME == time
-        # with self.assertRaisesRegex(
-        #     dbt.exceptions.DbtRuntimeError,
-        #     self.errMsg,
-        # ):
-        #     connections._get_http_path(node, creds)
-
-        # creds.compute = {"foo": {"http_path": "alternate_path"}}
-        # path = connections._get_http_path(node, creds)
-        # self.assertEqual("alternate_path", path)
-
-    def test_get_max_idle_creds(self):
-        creds_idle_time = 77
-        creds = DatabricksCredentials(connect_max_idle=creds_idle_time)
-
-        # No node so value should come from creds
-        time = connections._get_max_idle_time(None, creds)
-        assert creds_idle_time == time
-
-        node = nodes.ModelNode(
-            relation_name="a_relation",
-            database="database",
-            schema="schema",
-            name="node_name",
-            resource_type="model",
-            package_name="package",
-            path="path",
-            original_file_path="orig_path",
-            unique_id="uniqueID",
-            fqn=[],
-            alias="alias",
-            checksum=None,
-        )
-
-        # node has no configuration so should get value from creds
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-        # empty configuration should get value from creds
-        node.config = model_config.ModelConfig()
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-        # node with no extras in configuration should get value from creds
-        node.config._extra = {}
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-        # node that specifies a compute with no corresponding definition should get value from creds
-        node.config._extra["databricks_compute"] = "foo"
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-        creds.compute = {}
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-        # if alternate compute doesn't specify a max time should get value from creds
-        creds.compute = {"foo": {}}
-        time = connections._get_max_idle_time(node, creds)
-        assert creds_idle_time == time
-
-    def test_get_max_idle_compute(self):
-        creds_idle_time = 88
-        compute_idle_time = 77
-        creds = DatabricksCredentials(connect_max_idle=creds_idle_time)
-        creds.compute = {"foo": {"connect_max_idle": compute_idle_time}}
-
-        node = nodes.SnapshotNode(
-            config=None,
-            relation_name="a_relation",
-            database="database",
-            schema="schema",
-            name="node_name",
-            resource_type="model",
-            package_name="package",
-            path="path",
-            original_file_path="orig_path",
-            unique_id="uniqueID",
-            fqn=[],
-            alias="alias",
-            checksum=None,
-        )
-
-        node.config = model_config.SnapshotConfig()
-        node.config._extra = {"databricks_compute": "foo"}
-
-        time = connections._get_max_idle_time(node, creds)
-        assert compute_idle_time == time
-
-    def test_get_max_idle_invalid(self):
-        creds_idle_time = "foo"
-        compute_idle_time = "bar"
-        creds = DatabricksCredentials(connect_max_idle=creds_idle_time)
-        creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}}
-
-        node = nodes.SnapshotNode(
-            config=None,
-            relation_name="a_relation",
-            database="database",
-            schema="schema",
-            name="node_name",
-            resource_type="model",
-            package_name="package",
-            path="path",
-            original_file_path="orig_path",
-            unique_id="uniqueID",
-            fqn=[],
-            alias="alias",
-            checksum=None,
-        )
-
-        node.config = model_config.SnapshotConfig()
-
-        with pytest.raises(DbtRuntimeError) as info:
-            connections._get_max_idle_time(node, creds)
-        assert (
-            f"{creds_idle_time} is not a valid value for connect_max_idle. "
-            "Must be a number of seconds."
-        ) in str(info.value)
-
-        node.config._extra["databricks_compute"] = "alternate_compute"
-        with pytest.raises(DbtRuntimeError) as info:
-            connections._get_max_idle_time(node, creds)
-        assert (
-            f"{compute_idle_time} is not a valid value for connect_max_idle. "
-            "Must be a number of seconds."
-        ) in str(info.value)
-
-        creds.compute["alternate_compute"]["connect_max_idle"] = "1.2.3"
-        with pytest.raises(DbtRuntimeError) as info:
-            connections._get_max_idle_time(node, creds)
-        assert (
-            "1.2.3 is not a valid value for connect_max_idle. " "Must be a number of seconds."
-        ) in str(info.value)
-
-        creds.compute["alternate_compute"]["connect_max_idle"] = "1,002.3"
-        with pytest.raises(DbtRuntimeError) as info:
-            connections._get_max_idle_time(node, creds)
-        assert (
-            "1,002.3 is not a valid value for connect_max_idle. " "Must be a number of seconds."
-        ) in str(info.value)
-
-    def test_get_max_idle_simple_string_conversion(self):
-        creds_idle_time = "12"
-        compute_idle_time = "34"
-        creds = DatabricksCredentials(connect_max_idle=creds_idle_time)
-        creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}}
-
-        node = nodes.SnapshotNode(
-            config=None,
-            relation_name="a_relation",
-            database="database",
-            schema="schema",
-            name="node_name",
-            resource_type="model",
-            package_name="package",
-            path="path",
-            original_file_path="orig_path",
-            unique_id="uniqueID",
-            fqn=[],
-            alias="alias",
-            checksum=None,
-        )
-
-        node.config = model_config.SnapshotConfig()
-
-        time = connections._get_max_idle_time(node, creds)
-        assert float(creds_idle_time) == time
-
-        node.config._extra["databricks_compute"] = "alternate_compute"
-        time = connections._get_max_idle_time(node, creds)
-        assert float(compute_idle_time) == time
-
-        creds.compute["alternate_compute"]["connect_max_idle"] = " 56 "
-        time = connections._get_max_idle_time(node, creds)
-        assert 56 == time
diff --git a/tests/unit/test_query_config.py b/tests/unit/test_query_config.py
new file mode 100644
index 000000000..61813200d
--- /dev/null
+++ b/tests/unit/test_query_config.py
@@ -0,0 +1,93 @@
+import pytest
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.databricks import connections
+from dbt.adapters.databricks.connections import QueryConfigUtils, QueryContextWrapper
+from dbt.adapters.databricks.credentials import DatabricksCredentials
+
+
+class TestQueryConfigUtils:
+    """Test the various cases for determining a specified warehouse."""
+
+    @pytest.fixture(scope="class")
+    def err_msg(self):
+        return (
+            "Compute resource foo does not exist or does not specify http_path,"
+            " relation: a_relation"
+        )
+
+    @pytest.fixture(scope="class")
+    def path(self):
+        return "my_http_path"
+
+    @pytest.fixture
+    def creds(self, path):
+        return DatabricksCredentials(http_path=path)
+
+    def test_get_http_path__empty(self, path, creds):
+        assert QueryConfigUtils.get_http_path(QueryContextWrapper(), creds) == path
+
+    def test_get_http_path__no_compute(self, path, creds):
+        assert (
+            QueryConfigUtils.get_http_path(QueryContextWrapper(relation_name="a_relation"), creds)
+            == path
+        )
+
+    def test_get_http_path__missing_compute(self, creds, err_msg):
+        context = QueryContextWrapper(compute_name="foo", relation_name="a_relation")
+        with pytest.raises(DbtRuntimeError) as exc:
+            QueryConfigUtils.get_http_path(context, creds)
+
+        assert err_msg in str(exc.value)
+
+    def test_get_http_path__empty_compute(self, creds, err_msg):
+        context = QueryContextWrapper(compute_name="foo", relation_name="a_relation")
+        creds.compute = {"foo": {}}
+        with pytest.raises(DbtRuntimeError) as exc:
+            QueryConfigUtils.get_http_path(context, creds)
+
+        assert err_msg in str(exc.value)
+
+    def test_get_http_path__matching_compute(self, creds):
+        context = QueryContextWrapper(compute_name="foo", relation_name="a_relation")
+        creds.compute = {"foo": {"http_path": "alternate_path"}}
+        assert "alternate_path" == QueryConfigUtils.get_http_path(context, creds)
+
+    def test_get_max_idle__no_config(self, creds):
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(), creds)
+        assert connections.DEFAULT_MAX_IDLE_TIME == time
+
+    def test_get_max_idle__no_matching_compute(self, creds):
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds)
+        assert connections.DEFAULT_MAX_IDLE_TIME == time
+
+    def test_get_max_idle__compute_without_details(self, creds):
+        creds.compute = {"foo": {}}
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds)
+        assert connections.DEFAULT_MAX_IDLE_TIME == time
+
+    def test_get_max_idle__creds_but_no_context(self, creds):
+        creds.connect_max_idle = 77
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(), creds)
+        assert 77 == time
+
+    def test_get_max_idle__matching_compute_no_value(self, creds):
+        creds.connect_max_idle = 77
+        creds.compute = {"foo": {}}
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds)
+        assert 77 == time
+
+    def test_get_max_idle__matching_compute(self, creds):
+        creds.compute = {"foo": {"connect_max_idle": "88"}}
+        creds.connect_max_idle = 77
+        time = QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds)
+        assert 88 == time
+
+    def test_get_max_idle__invalid_config(self, creds):
+        creds.compute = {"foo": {"connect_max_idle": "bar"}}
+
+        with pytest.raises(DbtRuntimeError) as info:
+            QueryConfigUtils.get_max_idle_time(QueryContextWrapper(compute_name="foo"), creds)
+        assert (
+            "bar is not a valid value for connect_max_idle. Must be a number of seconds."
+        ) in str(info.value)