Skip to content

Commit

Permalink
Use Databricks SDK for OAuth (#22)
Browse files Browse the repository at this point in the history
* Use Databricks SDK for OAuth

Signed-off-by: B-Step62 <[email protected]>

* comment

Signed-off-by: B-Step62 <[email protected]>

---------

Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 authored Sep 25, 2024
1 parent dcef363 commit ddf33cf
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 53 deletions.
2 changes: 1 addition & 1 deletion libs/databricks/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/databricks/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ codespell = "^2.2.6"
optional = true

[tool.poetry.group.test_integration.dependencies]
databricks-sdk = "^0.32.3"
langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain" }
langgraph = "^0.2.27"
pytest-timeout = "^2.3.1"
Expand Down
9 changes: 9 additions & 0 deletions libs/databricks/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,13 @@ def chatbot(state: State):
{"messages": [("user", "Subtract 5 from it")]},
config={"configurable": {"thread_id": "1"}},
)

# Interestingly, the agent sometimes mistakes the subtraction for addition:(
# In such case, the agent asks for a retry so we need one more step.
if "Let me try again." in response["messages"][-1].content:
response = graph.invoke(
{"messages": [("user", "Ok, try again")]},
config={"configurable": {"thread_id": "1"}},
)

assert "40" in response["messages"][-1].content
66 changes: 14 additions & 52 deletions libs/databricks/tests/integration_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
"""

import os
import time
from datetime import timedelta

import pytest
import requests
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.jobs import RunLifecycleStateV2State, TerminationTypeType


@pytest.mark.timeout(3600)
Expand All @@ -23,62 +24,23 @@ def test_vectorstore():
because the setup is too complex to run within a single python file.
Thereby, this test simply triggers the workflow by calling the REST API.
"""
required_env_vars = ["DATABRICKS_HOST", "DATABRICKS_TOKEN", "VS_TEST_JOB_ID"]
for var in required_env_vars:
assert os.getenv(var), f"Please set the environment variable {var}."

test_endpoint = os.getenv("DATABRICKS_HOST")
test_job_id = os.getenv("VS_TEST_JOB_ID")
headers = {
"Authorization": f"Bearer {os.getenv('DATABRICKS_TOKEN')}",
}
if not test_job_id:
raise RuntimeError("Please set the environment variable VS_TEST_JOB_ID")

w = WorkspaceClient()

# Check if there is any ongoing job run
response = requests.get(
f"{test_endpoint}/api/2.1/jobs/runs/list",
json={
"job_id": test_job_id,
"active_only": True,
},
headers=headers,
)
no_active_run = len(response.json().get("runs", [])) == 0
run_list = list(w.jobs.list_runs(job_id=test_job_id, active_only=True))
no_active_run = len(run_list) == 0
assert no_active_run, "There is an ongoing job run. Please wait for it to complete."

# Trigger the workflow
# TODO: We are going to replace this with the Databricks SDK once the vector store
# class is also migrated to the SDK.
response = requests.post(
f"{test_endpoint}/api/2.1/jobs/run-now",
json={
"job_id": test_job_id,
},
headers=headers,
)

assert response.status_code == 200, "Failed to trigger the workflow."

job_url = f"{test_endpoint}/jobs/{test_job_id}/runs/{response.json()['run_id']}"
response = w.jobs.run_now(job_id=test_job_id)
job_url = f"{w.config.host}/jobs/{test_job_id}/runs/{response.run_id}"
print(f"Started the job at {job_url}") # noqa: T201

# Wait for the job to complete
while True:
response = requests.get(
f"{test_endpoint}/api/2.1/jobs/runs/get",
json={
"run_id": response.json()["run_id"],
},
headers=headers,
)

assert response.status_code == 200, "Failed to get the job status."

status = response.json()["status"]
if status["state"] == "TERMINATED":
if status["termination_details"]["type"] == "SUCCESS":
break
else:
assert False, "Job failed. Please check the logs in the workspace."

time.sleep(60)
print("Job is still running...") # noqa: T201
result = response.result(timeout=timedelta(seconds=3600))
assert result.status.state == RunLifecycleStateV2State.TERMINATED
assert result.status.termination_details.type == TerminationTypeType.SUCCESS

0 comments on commit ddf33cf

Please sign in to comment.