Skip to content

Commit

Permalink
Adding Databricks SQL Connector
Browse files Browse the repository at this point in the history
  • Loading branch information
DanFitzgibbon authored and lucasfcosta committed Nov 1, 2024
1 parent 0ceab2e commit 7ba1eb5
Show file tree
Hide file tree
Showing 27 changed files with 1,374 additions and 101 deletions.
3 changes: 2 additions & 1 deletion apps/api/jupyter-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ openpyxl==3.1.2
mysqlclient==2.2.4
pymongo==4.8.0
snowflake-connector-python==3.12.2
snowflake-sqlalchemy==1.6.1
snowflake-sqlalchemy==1.6.1
databricks-sql-connector[sqlalchemy]==3.4.0
1 change: 1 addition & 0 deletions apps/api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"@briefer/database": "*",
"@briefer/editor": "*",
"@briefer/types": "*",
"@databricks/sql": "^1.8.4",
"@grpc/grpc-js": "^1.11.1",
"@jupyterlab/services": "^7.1.1",
"@kubernetes/client-node": "^0.20.0",
Expand Down
6 changes: 6 additions & 0 deletions apps/api/src/auth/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,5 +318,11 @@ export const isAuthorizedForDataSource = async (
const result = await prisma().snowflakeDataSource.findFirst(query)
return result !== null
}
case 'databrickssql': {
const result = await prisma().databricksSQLDataSource.findFirst(
query,
)
return result !== null
}
}
}
43 changes: 43 additions & 0 deletions apps/api/src/datasources/databrickssql.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { config } from '../config/index.js'
import prisma, { DatabricksSQLDataSource } from '@briefer/database'
import { DataSourceStatus } from './index.js'
import { pingDatabricksSQL } from '../python/query/databrickssql.js'

export async function ping(ds: DatabricksSQLDataSource): Promise<DatabricksSQLDataSource> {
const lastConnection = new Date()
const err = await pingDatabricksSQL(ds, config().DATASOURCES_ENCRYPTION_KEY)

if (!err) {
return updateConnStatus(ds, {
connStatus: 'online',
lastConnection,
})
}

return updateConnStatus(ds, { connStatus: 'offline', connError: err })
}

export async function updateConnStatus(
ds: DatabricksSQLDataSource,
status: DataSourceStatus
): Promise<DatabricksSQLDataSource> {
const newDs = await prisma().databricksSQLDataSource.update({
where: { id: ds.id },
data: {
connStatus: status.connStatus,
lastConnection:
status.connStatus === 'online' ? status.lastConnection : undefined,
connError:
status.connStatus === 'offline'
? JSON.stringify(status.connError)
: undefined,
},
})

return {
...ds,
connStatus: newDs.connStatus,
lastConnection: newDs.lastConnection?.toISOString() ?? null,
connError: newDs.connError,
}
}
6 changes: 6 additions & 0 deletions apps/api/src/datasources/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import * as mysql from './mysql.js'
import * as trino from './trino.js'
import * as sqlserver from './sqlserver.js'
import * as snowflake from './snowflake.js'
import * as databrickssql from './databrickssql.js'
import { DataSourceConnectionError } from '@briefer/types'
import { IOServer } from '../websocket/index.js'
import { broadcastDataSource } from '../websocket/workspace/data-sources.js'
Expand Down Expand Up @@ -39,6 +40,8 @@ export async function ping(
return trino.ping(ds.config.data)
case 'snowflake':
return snowflake.ping(ds.config.data)
case 'databrickssql':
return databrickssql.ping(ds.config.data)
}
})()
broadcastDataSource(socket, ds)
Expand Down Expand Up @@ -89,5 +92,8 @@ export async function updateConnStatus<T extends Pick<APIDataSource, 'config'>>(
case 'snowflake':
ds.config.data = await snowflake.updateConnStatus(ds.config.data, status)
return ds
case 'databrickssql':
ds.config.data = await databrickssql.updateConnStatus(ds.config.data, status)
return ds
}
}
30 changes: 30 additions & 0 deletions apps/api/src/datasources/structure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { getTrinoSchema } from '../python/query/trino.js'
import { getSnowflakeSchema } from '../python/query/snowflake.js'
import { getAthenaSchema } from './athena.js'
import { getMySQLSchema } from './mysql.js'
import { getDatabricksSQLSchema } from '../python/query/databrickssql.js'
import { PythonExecutionError } from '../python/index.js'
import { getSqlServerSchema } from './sqlserver.js'
import { z } from 'zod'
Expand Down Expand Up @@ -134,6 +135,14 @@ async function getV2FromCache(
})
).structure
break
case 'databrickssql':
encrypted = (
await prisma().databricksSQLDataSource.findUniqueOrThrow({
where: { id: dataSourceId },
select: { structure: true },
})
).structure
break
}

if (encrypted === null) {
Expand Down Expand Up @@ -280,6 +289,12 @@ async function assignDataSourceSchemaId(
data: { dataSourceSchemaId: dbSchema.id },
})
return dbSchema.id
case 'databrickssql':
await prisma().databricksSQLDataSource.update({
where: { id: dataSourceId },
data: { dataSourceSchemaId: dbSchema.id },
})
return dbSchema.id
}
}

Expand Down Expand Up @@ -364,6 +379,14 @@ async function getFromCache(
})
).dataSourceSchema
break
case 'databrickssql':
schema = (
await prisma().databricksSQLDataSource.findUniqueOrThrow({
where: { id: dataSourceId },
select,
})
).dataSourceSchema
break
}

if (schema === null) {
Expand Down Expand Up @@ -584,6 +607,13 @@ async function _refreshDataSourceStructure(
onTable
)
break
case 'databrickssql':
await getDatabricksSQLSchema(
dataSource.config.data,
config().DATASOURCES_ENCRYPTION_KEY,
onTable
)
break
}

await updateQueue.onIdle()
Expand Down
64 changes: 64 additions & 0 deletions apps/api/src/python/query/databrickssql.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { v4 as uuidv4 } from 'uuid'
import { DatabricksSQLDataSource, getDatabaseURL } from '@briefer/database'
import { RunQueryResult, SuccessRunQueryResult } from '@briefer/types'
import {
getSQLAlchemySchema,
makeSQLAlchemyQuery,
pingSQLAlchemy,
} from './sqlalchemy.js'
import { OnTable } from '../../datasources/structure.js'

export async function makeDatabricksSQLQuery(
workspaceId: string,
sessionId: string,
queryId: string,
dataframeName: string,
datasource: DatabricksSQLDataSource,
encryptionKey: string,
sql: string,
onProgress: (result: SuccessRunQueryResult) => void
): Promise<[Promise<RunQueryResult>, () => Promise<void>]> {
const databaseUrl = await getDatabaseURL(
{ type: 'databrickssql', data: datasource },
encryptionKey
)

const jobId = uuidv4()
const query = `${sql} -- Briefer jobId: ${jobId}`

return makeSQLAlchemyQuery(
workspaceId,
sessionId,
dataframeName,
databaseUrl,
'databrickssql',
jobId,
query,
queryId,
onProgress
)
}

export function pingDatabricksSQL(
ds: DatabricksSQLDataSource,
encryptionKey: string
): Promise<null | Error> {
return pingSQLAlchemy(
{ type: 'databrickssql', data: ds },
encryptionKey,
null
)
}

export function getDatabricksSQLSchema(
ds: DatabricksSQLDataSource,
encryptionKey: string,
onTable: OnTable
): Promise<void> {
return getSQLAlchemySchema(
{ type: 'databrickssql', data: ds },
encryptionKey,
null,
onTable
)
}
13 changes: 13 additions & 0 deletions apps/api/src/python/query/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { makeSnowflakeQuery } from './snowflake.js'
import { updateConnStatus } from '../../datasources/index.js'
import { getJupyterManager } from '../../jupyter/index.js'
import { makeSQLServerQuery } from './sqlserver.js'
import { makeDatabricksSQLQuery } from './databrickssql.js'

export async function makeSQLQuery(
workspaceId: string,
Expand Down Expand Up @@ -147,6 +148,18 @@ export async function makeSQLQuery(
onProgress
)
break
case 'databrickssql':
result = await makeDatabricksSQLQuery(
workspaceId,
sessionId,
queryId,
dataframeName,
datasource.data,
encryptionKey,
sql,
onProgress
)
break
}

result[0].then(async (r) => {
Expand Down
3 changes: 2 additions & 1 deletion apps/api/src/python/query/sqlalchemy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ export async function makeSQLAlchemyQuery(
| 'psql'
| 'redshift'
| 'trino'
| 'snowflake',
| 'snowflake'
| 'databrickssql',
jobId: string,
query: string,
queryId: string,
Expand Down
2 changes: 2 additions & 0 deletions apps/api/src/python/writeback/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,7 @@ export async function writeback(
case 'snowflake':
case 'trino':
throw new Error(`${datasource.type} writeback not implemented`)
case 'databrickssql':
throw new Error(`${datasource.type} writeback not implemented`)
}
}
42 changes: 42 additions & 0 deletions apps/api/src/v1/workspaces/workspace/data-sources/data-source.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import {
getSnowflakeDataSource,
updateSnowflakeDataSource,
deleteSnowflakeDataSource,
getDatabricksSQLDataSource,
updateDatabricksSQLDataSource,
deleteDatabricksSQLDataSource,
} from '@briefer/database'
import { z } from 'zod'
import { getParam } from '../../../../utils/express.js'
Expand Down Expand Up @@ -139,6 +142,19 @@ const dataSourceRouter = (socketServer: IOServer) => {
notes: z.string(),
}),
}),
z.object({
type: z.literal('databrickssql'),
data: z.object({
id: z.string().min(1),
name: z.string().min(1),
hostname: z.string().min(1),
http_path: z.string().min(1),
token: z.string().min(1),
catalog: z.string(),
schema: z.string(),
notes: z.string(),
}),
}),
])

router.put('/', async (req, res) => {
Expand All @@ -162,6 +178,7 @@ const dataSourceRouter = (socketServer: IOServer) => {
getSQLServerDataSource(workspaceId, dataSourceId),
getTrinoDataSource(workspaceId, dataSourceId),
getSnowflakeDataSource(workspaceId, dataSourceId),
getDatabricksSQLDataSource(workspaceId, dataSourceId),
])
).find((e) => e !== null)
if (!existingDb) {
Expand Down Expand Up @@ -298,6 +315,17 @@ const dataSourceRouter = (socketServer: IOServer) => {
)
break
}
case 'databrickssql': {
await updateDatabricksSQLDataSource(
{
...data.data,
id: dataSourceId,
token: data.data.token === ''? undefined : data.data.token,
},
config().DATASOURCES_ENCRYPTION_KEY
)
break
}
}

const ds = await getDatasource(workspaceId, dataSourceId, data.type)
Expand Down Expand Up @@ -444,6 +472,19 @@ const dataSourceRouter = (socketServer: IOServer) => {
}
}

const targetDatabricksSQLDb = await recoverFromNotFound(
deleteDatabricksSQLDataSource(workspaceId, targetId)
)
if (targetDatabricksSQLDb) {
return {
status: 200,
payload: {
type: 'databrickssql',
data: targetDatabricksSQLDb,
},
}
}

return { status: 404 }
}

Expand All @@ -468,6 +509,7 @@ const dataSourceRouter = (socketServer: IOServer) => {
z.literal('trino'),
z.literal('sqlserver'),
z.literal('snowflake'),
z.literal('databrickssql'),
]),
})
router.post('/ping', async (req, res) => {
Expand Down
Loading

0 comments on commit 7ba1eb5

Please sign in to comment.