Skip to content

Commit

Permalink
Python model specific session handling to prevent using invalid sessi…
Browse files Browse the repository at this point in the history
…ons (#547)
  • Loading branch information
rcypher-databricks authored Jan 12, 2024
2 parents daffce1 + 014fcb7 commit 212d046
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

### Fixes

- Fix for issue where long-running python models led to invalid session errors ([544](https://github.com/databricks/dbt-databricks/pull/544))
- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547))
- Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538))

## dbt-databricks 1.7.3 (Dec 12, 2023)
Expand Down
67 changes: 59 additions & 8 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,11 +737,26 @@ class DatabricksDBTConnection(Connection):
thread_identifier: Tuple[int, int] = (0, 0)
max_idle_time: int = DEFAULT_MAX_IDLE_TIME

# If the connection is being used for a model we want to track the model language.
# We do this because we need special handling for python models. Python models will
# acquire a connection, but do not actually use it to run the model. This can lead to the
# session timing out on the back end. However, when the connection is released we set the
# last_used_time, essentially indicating that the connection was in use while the python
# model was running. So the session is not refreshed by idle connection cleanup and errors
# the next time it is used.
language: Optional[str] = None

def _acquire(self, node: Optional[ResultNode]) -> None:
"""Indicate that this connection is in use."""
logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}")
self._log_usage(node)
self.acquire_release_count += 1
if self.last_used_time is None:
self.last_used_time = time.time()
if node and hasattr(node, "language"):
self.language = node.language
else:
self.language = None

def _release(self) -> None:
"""Indicate that this connection is not in use."""
Expand All @@ -751,7 +766,9 @@ def _release(self) -> None:
if self.acquire_release_count > 0:
self.acquire_release_count -= 1

if self.acquire_release_count == 0:
# We don't update the last_used_time for python models because the python model
# is submitted through a different mechanism and doesn't actually use the connection.
if self.acquire_release_count == 0 and self.language != "python":
self.last_used_time = time.time()

def _get_idle_time(self) -> float:
Expand All @@ -765,7 +782,7 @@ def _get_conn_info_str(self) -> str:
return (
f"name: {self.name}, thread: {self.thread_identifier}, "
f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count},"
f" idle time: {self._get_idle_time()}s"
f" idle time: {self._get_idle_time()}s, language: {self.language}"
)

def _log_usage(self, node: Optional[ResultNode]) -> None:
Expand All @@ -783,6 +800,13 @@ def _log_usage(self, node: Optional[ResultNode]) -> None:
else:
logger.debug(f"Thread {self.thread_identifier} using default compute resource.")

def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:
logger.debug(f"DatabricksDBTConnection._reset_handle: {self._get_conn_info_str()}")
self.handle = LazyHandle(open)
# Reset last_used_time to None because by refreshing this connection becomes associated
# with a new session that hasn't been used yet.
self.last_used_time = None


class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
Expand Down Expand Up @@ -889,6 +913,19 @@ def release(self) -> None:

conn._release()

# override
@classmethod
def close(cls, connection: Connection) -> Connection:
if not USE_LONG_SESSIONS:
return super().close(connection)

try:
return super().close(connection)
except Exception as e:
logger.warning(f"ignoring error when closing connection: {e}")
connection.state = ConnectionState.CLOSED
return connection

# override
def cleanup_all(self) -> None:
if not USE_LONG_SESSIONS:
Expand Down Expand Up @@ -1063,12 +1100,26 @@ def _cleanup_idle_connections(self) -> None:
), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS"

with self.lock:
for thread_conns in self.threads_compute_connections.values():
for conn in thread_conns.values():
if conn.acquire_release_count == 0 and conn._idle_too_long():
logger.debug(f"closing idle connection: {conn._get_conn_info_str()}")
self.close(conn)
conn.handle = LazyHandle(self._open2)
# Get all connections associated with this thread. There can be multiple connections
# if different models use different compute resources
thread_conns = self._get_compute_connections()
for conn in thread_conns.values():
# Generally speaking we only want to close/refresh the connection if the
# acquire_release_count is zero. i.e. the connection is not currently in use.
# However python models acquire a connection then run the pyton model, which
# doesn't actually use the connection. If the python model takes lone enought to
# run the connection can be idle long enough to timeout on the back end.
# If additional sql needs to be run after the python model, but before the
# connection is released, the connection needs to be refreshed or there will
# be a failure. Making an exception when language is 'python' allows the
# the call to _cleanup_idle_connections from get_thread_connection to refresh the
# connection in this scenario.
if (
conn.acquire_release_count == 0 or conn.language == "python"
) and conn._idle_too_long():
logger.debug(f"closing idle connection: {conn._get_conn_info_str()}")
self.close(conn)
conn._reset_handle(self._open2)

def get_thread_connection(self) -> Connection:
if USE_LONG_SESSIONS:
Expand Down

0 comments on commit 212d046

Please sign in to comment.