From 61c0ab10dc07b53f7e306c04337989a8692f8e03 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Fri, 12 Jan 2024 15:50:23 -0800 Subject: [PATCH 1/3] Adding a default retry strategy in python submissions --- dbt/adapters/databricks/python_submissions.py | 53 +++++++++++++------ tests/unit/python/test_python_submissions.py | 22 ++++---- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/dbt/adapters/databricks/python_submissions.py b/dbt/adapters/databricks/python_submissions.py index bada5548d..e3091854e 100644 --- a/dbt/adapters/databricks/python_submissions.py +++ b/dbt/adapters/databricks/python_submissions.py @@ -1,19 +1,23 @@ from typing import Any, Dict, Tuple, Optional, Callable +from requests import Session + from dbt.adapters.databricks.__version__ import version from dbt.adapters.databricks.connections import DatabricksCredentials from dbt.adapters.databricks import utils import base64 import time -import requests import uuid +from urllib3.util.retry import Retry + from dbt.events import AdapterLogger import dbt.exceptions from dbt.adapters.base import PythonJobHelper from dbt.adapters.spark import __version__ from databricks.sdk.core import CredentialsProvider +from requests.adapters import HTTPAdapter logger = AdapterLogger("Databricks") @@ -31,6 +35,13 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No self.parsed_model = parsed_model self.timeout = self.get_timeout() self.polling_interval = DEFAULT_POLLING_INTERVAL + + # This should be passed in, but not sure where this is actually instantiated + retry_strategy = Retry(total=4, backoff_factor=0.5) + adapter = HTTPAdapter(max_retries=retry_strategy) + self.session = Session() + self.session.mount("https://", adapter) + self.check_credentials() self.auth_header = { "Authorization": f"Bearer {self.credentials.token}", @@ -53,7 +64,7 @@ def check_credentials(self) -> None: ) def _create_work_dir(self, path: str) -> None: - response = requests.post( + response = self.session.post( f"https://{self.credentials.host}/api/2.0/workspace/mkdirs", headers=self.auth_header, json={ @@ -73,7 +84,7 @@ def _update_with_acls(self, cluster_dict: dict) -> dict: def _upload_notebook(self, path: str, compiled_code: str) -> None: b64_encoded_content = base64.b64encode(compiled_code.encode()).decode() - response = requests.post( + response = self.session.post( f"https://{self.credentials.host}/api/2.0/workspace/import", headers=self.auth_header, json={ @@ -118,7 +129,7 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str: libraries.append(lib) job_spec.update({"libraries": libraries}) # type: ignore - submit_response = requests.post( + submit_response = self.session.post( f"https://{self.credentials.host}/api/2.1/jobs/runs/submit", headers=self.auth_header, json=job_spec, @@ -143,7 +154,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> No run_id = self._submit_job(whole_file_path, cluster_spec) self.polling( - status_func=requests.get, + status_func=self.session.get, status_func_kwargs={ "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", "headers": self.auth_header, @@ -155,7 +166,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> No ) # get end state to return to user - run_output = requests.get( + run_output = self.session.get( f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=self.auth_header, ) @@ -217,11 +228,16 @@ def submit(self, compiled_code: str) -> None: class DBContext: def __init__( - self, credentials: DatabricksCredentials, cluster_id: str, auth_header: dict + self, + credentials: DatabricksCredentials, + cluster_id: str, + auth_header: dict, + session: Session, ) -> None: self.auth_header = auth_header self.cluster_id = cluster_id self.host = credentials.host + self.session = session def create(self) -> str: # https://docs.databricks.com/dev-tools/api/1.2/index.html#create-an-execution-context @@ -235,7 +251,7 @@ def create(self) -> str: if current_status != "RUNNING": self._wait_for_cluster_to_start() - response = requests.post( + response = self.session.post( f"https://{self.host}/api/1.2/contexts/create", headers=self.auth_header, json={ @@ -251,7 +267,7 @@ def create(self) -> str: def destroy(self, context_id: str) -> str: # https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context - response = requests.post( + response = self.session.post( f"https://{self.host}/api/1.2/contexts/destroy", headers=self.auth_header, json={ @@ -268,7 +284,7 @@ def destroy(self, context_id: str) -> str: def get_cluster_status(self) -> Dict: # https://docs.databricks.com/dev-tools/api/latest/clusters.html#get - response = requests.get( + response = self.session.get( f"https://{self.host}/api/2.0/clusters/get", headers=self.auth_header, json={"cluster_id": self.cluster_id}, @@ -291,7 +307,7 @@ def start_cluster(self) -> None: logger.debug(f"Sending restart command for cluster id {self.cluster_id}") - response = requests.post( + response = self.session.post( f"https://{self.host}/api/2.0/clusters/start", headers=self.auth_header, json={"cluster_id": self.cluster_id}, @@ -327,15 +343,20 @@ def get_elapsed() -> float: class DBCommand: def __init__( - self, credentials: DatabricksCredentials, cluster_id: str, auth_header: dict + self, + credentials: DatabricksCredentials, + cluster_id: str, + auth_header: dict, + session: Session, ) -> None: self.auth_header = auth_header self.cluster_id = cluster_id self.host = credentials.host + self.session = session def execute(self, context_id: str, command: str) -> str: # https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command - response = requests.post( + response = self.session.post( f"https://{self.host}/api/1.2/commands/execute", headers=self.auth_header, json={ @@ -354,7 +375,7 @@ def execute(self, context_id: str, command: str) -> str: def status(self, context_id: str, command_id: str) -> Dict[str, Any]: # https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command - response = requests.get( + response = self.session.get( f"https://{self.host}/api/1.2/commands/status", headers=self.auth_header, params={ @@ -383,8 +404,8 @@ def submit(self, compiled_code: str) -> None: config = {"existing_cluster_id": self.cluster_id} self._submit_through_notebook(compiled_code, self._update_with_acls(config)) else: - context = DBContext(self.credentials, self.cluster_id, self.auth_header) - command = DBCommand(self.credentials, self.cluster_id, self.auth_header) + context = DBContext(self.credentials, self.cluster_id, self.auth_header, self.session) + command = DBCommand(self.credentials, self.cluster_id, self.auth_header, self.session) context_id = context.create() try: command_id = command.execute(context_id, compiled_code) diff --git a/tests/unit/python/test_python_submissions.py b/tests/unit/python/test_python_submissions.py index d84de5921..975bb24fc 100644 --- a/tests/unit/python/test_python_submissions.py +++ b/tests/unit/python/test_python_submissions.py @@ -1,25 +1,29 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import Mock +from mock import PropertyMock from dbt.adapters.databricks.connections import DatabricksCredentials from dbt.adapters.databricks.python_submissions import DBContext, BaseDatabricksHelper class TestDatabricksPythonSubmissions(unittest.TestCase): - @patch("requests.get") - @patch("requests.post") - def test_start_cluster_returns_on_receiving_running_state(self, mock_post, mock_get): + def test_start_cluster_returns_on_receiving_running_state(self): + session_mock = Mock() # Mock the start command - mock_post.return_value.status_code = 200 + post_mock = Mock() + post_mock.status_code = 200 + session_mock.post.return_value = post_mock # Mock the status command - mock_get.return_value.status_code = 200 - mock_get.return_value.json = Mock(return_value={"state": "RUNNING"}) + get_mock = Mock() + get_mock.status_code = 200 + get_mock.json.return_value = {"state": "RUNNING"} + session_mock.get.return_value = get_mock - context = DBContext(Mock(), None, None) + context = DBContext(Mock(), None, None, session_mock) context.start_cluster() - mock_get.assert_called_once() + session_mock.get.assert_called_once() class DatabricksTestHelper(BaseDatabricksHelper): From 340c806f42642b079d4b8bc888fd1dbe45618c0d Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Fri, 12 Jan 2024 16:08:26 -0800 Subject: [PATCH 2/3] passing linter --- tests/unit/python/test_python_submissions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/python/test_python_submissions.py b/tests/unit/python/test_python_submissions.py index 975bb24fc..e3a6f5741 100644 --- a/tests/unit/python/test_python_submissions.py +++ b/tests/unit/python/test_python_submissions.py @@ -1,7 +1,6 @@ import unittest from unittest.mock import Mock -from mock import PropertyMock from dbt.adapters.databricks.connections import DatabricksCredentials from dbt.adapters.databricks.python_submissions import DBContext, BaseDatabricksHelper From 6c43ba79bfdfae1b2242c018b3cd241d5e5f4dfa Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Fri, 12 Jan 2024 16:12:04 -0800 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb2a6b9cc..324547ffc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,14 @@ ### Fixes -- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547)) +- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547)) - Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538)) - Fix dbt incremental_strategy behavior by fixing schema table existing check (thanks @case-k-git!) ([530](https://github.com/databricks/dbt-databricks/pull/530)) +### Under the Hood + +- Adding retries around API calls in python model submission ([549](https://github.com/databricks/dbt-databricks/pull/549)) + ## dbt-databricks 1.7.3 (Dec 12, 2023) ### Fixes