diff --git a/CHANGELOG.md b/CHANGELOG.md index f1d0df1d1..b5288d71b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ### Fixes -- Fix for issue where long-running python models led to invalid session errors ([544](https://github.com/databricks/dbt-databricks/pull/544)) +- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547)) - Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538)) ## dbt-databricks 1.7.3 (Dec 12, 2023) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 2e429062f..63fb9f5b4 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -737,11 +737,26 @@ class DatabricksDBTConnection(Connection): thread_identifier: Tuple[int, int] = (0, 0) max_idle_time: int = DEFAULT_MAX_IDLE_TIME + # If the connection is being used for a model we want to track the model language. + # We do this because we need special handling for python models. Python models will + # acquire a connection, but do not actually use it to run the model. This can lead to the + # session timing out on the back end. However, when the connection is released we set the + # last_used_time, essentially indicating that the connection was in use while the python + # model was running. So the session is not refreshed by idle connection cleanup and errors + # the next time it is used. + language: Optional[str] = None + def _acquire(self, node: Optional[ResultNode]) -> None: """Indicate that this connection is in use.""" logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}") self._log_usage(node) self.acquire_release_count += 1 + if self.last_used_time is None: + self.last_used_time = time.time() + if node and hasattr(node, "language"): + self.language = node.language + else: + self.language = None def _release(self) -> None: """Indicate that this connection is not in use.""" @@ -751,7 +766,9 @@ def _release(self) -> None: if self.acquire_release_count > 0: self.acquire_release_count -= 1 - if self.acquire_release_count == 0: + # We don't update the last_used_time for python models because the python model + # is submitted through a different mechanism and doesn't actually use the connection. + if self.acquire_release_count == 0 and self.language != "python": self.last_used_time = time.time() def _get_idle_time(self) -> float: @@ -765,7 +782,7 @@ def _get_conn_info_str(self) -> str: return ( f"name: {self.name}, thread: {self.thread_identifier}, " f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}," - f" idle time: {self._get_idle_time()}s" + f" idle time: {self._get_idle_time()}s, language: {self.language}" ) def _log_usage(self, node: Optional[ResultNode]) -> None: @@ -783,6 +800,13 @@ def _log_usage(self, node: Optional[ResultNode]) -> None: else: logger.debug(f"Thread {self.thread_identifier} using default compute resource.") + def _reset_handle(self, open: Callable[[Connection], Connection]) -> None: + logger.debug(f"DatabricksDBTConnection._reset_handle: {self._get_conn_info_str()}") + self.handle = LazyHandle(open) + # Reset last_used_time to None because by refreshing this connection becomes associated + # with a new session that hasn't been used yet. + self.last_used_time = None + class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" @@ -889,6 +913,19 @@ def release(self) -> None: conn._release() + # override + @classmethod + def close(cls, connection: Connection) -> Connection: + if not USE_LONG_SESSIONS: + return super().close(connection) + + try: + return super().close(connection) + except Exception as e: + logger.warning(f"ignoring error when closing connection: {e}") + connection.state = ConnectionState.CLOSED + return connection + # override def cleanup_all(self) -> None: if not USE_LONG_SESSIONS: @@ -1063,12 +1100,26 @@ def _cleanup_idle_connections(self) -> None: ), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS" with self.lock: - for thread_conns in self.threads_compute_connections.values(): - for conn in thread_conns.values(): - if conn.acquire_release_count == 0 and conn._idle_too_long(): - logger.debug(f"closing idle connection: {conn._get_conn_info_str()}") - self.close(conn) - conn.handle = LazyHandle(self._open2) + # Get all connections associated with this thread. There can be multiple connections + # if different models use different compute resources + thread_conns = self._get_compute_connections() + for conn in thread_conns.values(): + # Generally speaking we only want to close/refresh the connection if the + # acquire_release_count is zero. i.e. the connection is not currently in use. + # However python models acquire a connection then run the pyton model, which + # doesn't actually use the connection. If the python model takes lone enought to + # run the connection can be idle long enough to timeout on the back end. + # If additional sql needs to be run after the python model, but before the + # connection is released, the connection needs to be refreshed or there will + # be a failure. Making an exception when language is 'python' allows the + # the call to _cleanup_idle_connections from get_thread_connection to refresh the + # connection in this scenario. + if ( + conn.acquire_release_count == 0 or conn.language == "python" + ) and conn._idle_too_long(): + logger.debug(f"closing idle connection: {conn._get_conn_info_str()}") + self.close(conn) + conn._reset_handle(self._open2) def get_thread_connection(self) -> Connection: if USE_LONG_SESSIONS: