Rag SQL AI #310

Merged · 3 commits · Jan 7, 2025
1 change: 1 addition & 0 deletions apps/api/package.json
@@ -50,6 +50,7 @@
"lru-cache": "^10.2.2",
"mssql": "^11.0.1",
"node-fetch": "^3.3.2",
"openai": "^4.77.3",
"p-all": "^5.0.0",
"p-queue": "^8.0.1",
"parse-duration": "^1.1.0",
55 changes: 53 additions & 2 deletions apps/api/src/datasources/structure.ts
@@ -3,6 +3,8 @@ import prisma, {
DataSource,
DataSourceSchema as DBDataSourceSchema,
decrypt,
getWorkspaceById,
getWorkspaceWithSecrets,
} from '@briefer/database'
import { IOServer } from '../websocket/index.js'
import {
@@ -37,6 +39,7 @@ import { PythonExecutionError } from '../python/index.js'
import { getSqlServerSchema } from './sqlserver.js'
import { z } from 'zod'
import { splitEvery } from 'ramda'
import { createEmbedding } from '../embedding.js'

function decryptDBData(
dataSourceId: string,
@@ -519,19 +522,54 @@ async function _refreshDataSourceStructure(
socketServer: IOServer,
dataSource: APIDataSource
) {
const workspace = await getWorkspaceWithSecrets(
dataSource.config.data.workspaceId
)
if (!workspace) {
throw new Error(
`Failed to find Workspace(${dataSource.config.data.workspaceId}) for DataSource(${dataSource.config.data.id})`
)
}

const updateQueue = new PQueue({ concurrency: 1 })
const tables: { schema: string; table: string }[] = []
let defaultSchema = ''
const onTable: OnTable = (schema, tableName, table, defaultSchema) => {
defaultSchema = defaultSchema || schema
tables.push({ schema, table: tableName })
updateQueue.add(async () => {
const openAiApiKey = workspace.secrets?.openAiApiKey
let embedding: number[] | null = null
if (openAiApiKey) {
let ddl = `CREATE TABLE ${schema}.${tableName} (\n`
for (const c of table.columns) {
ddl += ` ${c.name} ${c.type}\n`
}
try {
embedding = await createEmbedding(
ddl,
decrypt(openAiApiKey, config().WORKSPACE_SECRETS_ENCRYPTION_KEY)
)
} catch (err) {
logger().error(
{
err,
dataSourceId: dataSource.config.data.id,
dataSourceType: dataSource.config.type,
schemaName: schema,
tableName,
},
'Failed to create embedding'
)
}
}

await prisma().dataSourceSchema.update({
where: { id: dataSource.structure.id },
data: { defaultSchema },
})

- await prisma().dataSourceSchemaTable.upsert({
const dbSchemaTable = await prisma().dataSourceSchemaTable.upsert({
where: {
dataSourceSchemaId_schema_name: {
dataSourceSchemaId: dataSource.structure.id,
@@ -550,6 +588,11 @@
},
})

if (embedding) {
await prisma()
.$queryRaw`UPDATE "DataSourceSchemaTable" SET embedding = ${embedding} WHERE id = ${dbSchemaTable.id}::uuid`
}

broadcastDataSource(socketServer, dataSource)
broadcastDataSourceSchemaTableUpdate(
socketServer,
@@ -831,7 +874,15 @@ export const OnTableProgress = z.object({
})
export type OnTableProgress = z.infer<typeof OnTableProgress>

export async function* listSchemaTables(dataSources: APIDataSource[]) {
export type SchemaTableItem = {
schemaName: string
tableName: string
dataSourceId: string
table: DataSourceTable
}
export async function* listSchemaTables(
dataSources: APIDataSource[]
): AsyncGenerator<SchemaTableItem> {
const schemaToDataSourceId = new Map(
dataSources.map((d) => [d.structure.id, d.config.data.id])
)
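
Note: the raw UPDATE above writes to an "embedding" column on "DataSourceSchemaTable" that no migration in this diff creates. A minimal sketch of the assumed schema change, using pgvector and the 1536 dimensions returned by text-embedding-3-small (the column name, dimension, and use of a raw statement are assumptions; the project presumably handles this in a Prisma migration elsewhere):

import prisma from '@briefer/database'

// Hypothetical, not part of this PR: a pgvector column holding one embedding
// per table. 1536 matches OpenAI's text-embedding-3-small.
export async function addEmbeddingColumn() {
  await prisma().$executeRawUnsafe(
    'ALTER TABLE "DataSourceSchemaTable" ADD COLUMN IF NOT EXISTS embedding vector(1536)'
  )
}
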
55 changes: 55 additions & 0 deletions apps/api/src/embedding.ts
@@ -0,0 +1,55 @@
import { createHash } from 'crypto'
import prisma from '@briefer/database'
import { OpenAI } from 'openai'
import { head } from 'ramda'
import { z } from 'zod'
import { logger } from './logger.js'
import { jsonString } from '@briefer/types'

export async function createEmbedding(input: string, openAiApiKey: string) {
const model = 'text-embedding-3-small'
const inputChecksum = createHash('sha256').update(input).digest('hex')
const rawExistingEmbeddings = await prisma()
.$queryRaw`SELECT embedding::text FROM "EmbeddingCache" WHERE "inputChecksum" = ${inputChecksum} AND model = ${model}`

const parsedExistingEmbeddings = z
.array(z.object({ embedding: jsonString.pipe(z.array(z.number())) }))
.safeParse(rawExistingEmbeddings)
if (parsedExistingEmbeddings.success) {
const embedding = head(parsedExistingEmbeddings.data)?.embedding
if (embedding) {
return embedding
}
} else {
logger().error(
{
err: parsedExistingEmbeddings.error,
},
'Failed to parse existing embeddings'
)
}

const openai = new OpenAI({ apiKey: openAiApiKey })
const embeddingResponse = await openai.embeddings.create({
model,
input,
})
const embedding = head(embeddingResponse.data)?.embedding
if (!embedding) {
throw new Error('OpenAI did not return any embeddings')
}

try {
await prisma()
.$queryRaw`INSERT INTO "EmbeddingCache" ("inputChecksum", model, embedding) VALUES (${inputChecksum}, ${model}, ${embedding}::vector) ON CONFLICT DO NOTHING`
} catch (err) {
logger().error(
{
err,
},
'Failed to insert embedding into cache'
)
}

return embedding
}
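
For reference, a minimal usage sketch of createEmbedding (illustrative only; it assumes the "EmbeddingCache" table queried above already exists and that an OpenAI key is available in the environment):

import { createEmbedding } from './embedding.js'

// The first call hits the OpenAI embeddings API and caches the result keyed
// by the SHA-256 checksum of the input plus the model name; a repeat call
// with the same input is served from the EmbeddingCache row.
const ddl = 'CREATE TABLE public.users (\n  id uuid\n  email text\n)'
const embedding = await createEmbedding(ddl, process.env.OPENAI_API_KEY ?? '')
console.log(embedding.length) // 1536 for text-embedding-3-small
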
9 changes: 9 additions & 0 deletions apps/api/src/websocket/index.ts
@@ -60,6 +60,15 @@ interface EmitEvents {
workspaceId: string
dataSource: APIDataSource
}) => void
'workspace-datasource-schema-tables': (msg: {
workspaceId: string
tables: {
dataSourceId: string
schemaName: string
tableName: string
table: DataSourceTable
}[]
}) => void
'workspace-datasource-schema-table-update': (msg: {
workspaceId: string
dataSourceId: string
17 changes: 15 additions & 2 deletions apps/api/src/websocket/workspace/data-sources.ts
@@ -10,6 +10,7 @@ import { Session } from '../../types.js'
import {
fetchDataSourceStructure,
listSchemaTables,
SchemaTableItem,
} from '../../datasources/structure.js'
import { DataSourceTable, uuidSchema } from '@briefer/types'

@@ -99,10 +100,22 @@ async function emitSchemas(
workspaceId: string,
dataSources: APIDataSource[]
) {
let batch: SchemaTableItem[] = []
for await (const schemaTable of listSchemaTables(dataSources)) {
- socket.emit('workspace-datasource-schema-table-update', {
batch.push(schemaTable)
if (batch.length >= 100) {
socket.emit('workspace-datasource-schema-tables', {
workspaceId,
tables: batch,
})
batch = []
}
}

if (batch.length > 0) {
socket.emit('workspace-datasource-schema-tables', {
workspaceId,
- ...schemaTable,
tables: batch,
})
}
}
60 changes: 55 additions & 5 deletions apps/api/src/yjs/v2/executor/ai/sql.ts
@@ -10,10 +10,11 @@ import {
updateSQLAISuggestions,
} from '@briefer/editor'
import * as Y from 'yjs'
- import {
import prisma, {
listDataSources,
getWorkspaceWithSecrets,
DataSource,
decrypt,
} from '@briefer/database'
import { logger } from '../../../../logger.js'
import { sqlEditStreamed } from '../../../../ai-api.js'
@@ -24,7 +25,10 @@ import {
fetchDataSourceStructureFromCache,
listSchemaTables,
} from '../../../../datasources/structure.js'
- import { DataSourceStructureStateV3 } from '@briefer/types'
import { DataSourceStructureStateV3, uuidSchema } from '@briefer/types'
import { createEmbedding } from '../../../../embedding.js'
import { config } from '../../../../config/index.js'
import { z } from 'zod'

async function editWithAI(
workspaceId: string,
@@ -70,7 +74,12 @@ async function editWithAI(
dataSource.data.id,
dataSource.type
)
- const tableInfo = await tableInfoFromStructure(dataSource, structure)
const tableInfo = await retrieveTableInfoForQuestion(
dataSource,
structure,
instructions,
workspace?.secrets?.openAiApiKey ?? null
)

event(assistantModelId)

@@ -110,6 +119,47 @@
)
}

async function retrieveTableInfoForQuestion(
datasource: DataSource,
structure: DataSourceStructureStateV3 | null,
question: string,
openAiApiKey: string | null
): Promise<string | null> {
if (!structure || !openAiApiKey) {
return tableInfoFromStructure(datasource, structure)
}

const questionEmbedding = await createEmbedding(
question,
decrypt(openAiApiKey, config().WORKSPACE_SECRETS_ENCRYPTION_KEY)
)

const raw = await prisma()
.$queryRaw`SELECT t.id, t.embedding <=> ${questionEmbedding}::vector AS distance FROM "DataSourceSchemaTable" t INNER JOIN "DataSourceSchema" s ON s.id = t."dataSourceSchemaId" WHERE s.id = ${structure.id}::uuid ORDER BY distance LIMIT 30`

const result = z.array(z.object({ id: uuidSchema })).parse(raw)

const tables = await prisma().dataSourceSchemaTable.findMany({
where: {
id: { in: result.map((r) => r.id) },
},
})

let tableInfo = ''
for (const table of tables) {
tableInfo += `${table.schema}.${table.name}\n`
const columns = z
.array(z.object({ name: z.string(), type: z.string() }))
.parse(table.columns)
for (const column of columns) {
tableInfo += `${column.name} ${column.type}\n`
}
tableInfo += '\n'
}

return tableInfo.trim()
}

async function tableInfoFromStructure(
config: DataSource,
structure: DataSourceStructureStateV3 | null
@@ -121,8 +171,8 @@
let result = ''
for await (const schemaTable of listSchemaTables([{ config, structure }])) {
result += `${schemaTable.schemaName}.${schemaTable.tableName}\n`
- for (const columns of schemaTable.table.columns) {
- result += `${columns.name} ${columns.type}\n`
for (const column of schemaTable.table.columns) {
result += `${column.name} ${column.type}\n`
}
result += '\n'
}
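
The retrieval query above ranks the schema's table rows by cosine distance (embedding <=> question embedding) and keeps the closest 30, which without an index is a sequential scan. No vector index is added in this diff; if one were ever needed, a pgvector HNSW index matching the cosine operator could look roughly like this (the index name and the choice of HNSW are assumptions):

import prisma from '@briefer/database'

// Hypothetical, not part of this PR: an approximate-nearest-neighbour index
// for the <=> (cosine distance) operator used by the retrieval query.
await prisma().$executeRawUnsafe(
  'CREATE INDEX IF NOT EXISTS "DataSourceSchemaTable_embedding_idx" ' +
    'ON "DataSourceSchemaTable" USING hnsw (embedding vector_cosine_ops)'
)
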
35 changes: 34 additions & 1 deletion apps/web/src/hooks/useDatasources.tsx
@@ -90,7 +90,7 @@ type UseDataSources = [
schemas: Schemas
isLoading: boolean
},
- API
API,
]
export const useDataSources = (workspaceId: string): UseDataSources => {
const [state, api] = useContext(Context)
@@ -186,6 +186,38 @@ export function DataSourcesProvider(props: Props) {
onDataSourceSchemaTableUpdate
)

const onDataSourceSchemaTables = ({
workspaceId,
tables,
}: {
workspaceId: string
tables: (DataSourceTable & {
dataSourceId: string
schemaName: string
tableName: string
table: DataSourceTable
})[]
}) => {
setState((state) => {
tables.forEach(({ dataSourceId, schemaName, tableName, table }) => {
const datasources = state.get(workspaceId)?.datasources ?? List()
const allSchemas = state.get(workspaceId)?.schemas ?? Map()
const dataSourceSchemas = allSchemas.get(dataSourceId) ?? Map()
const schema = dataSourceSchemas.get(schemaName)
const tables = {
...(schema?.tables ?? {}),
[tableName]: table,
}
state = state.set(workspaceId, {
datasources,
schemas: allSchemas.setIn([dataSourceId, schemaName], { tables }),
})
})
return state
})
}
socket.on('workspace-datasource-schema-tables', onDataSourceSchemaTables)

const onDataSourceSchemaTableRemoved = ({
workspaceId,
dataSourceId,
@@ -220,6 +252,7 @@ export function DataSourcesProvider(props: Props) {
'workspace-datasource-schema-table-update',
onDataSourceSchemaTableUpdate
)
socket.off('workspace-datasource-schema-tables', onDataSourceSchemaTables)
socket.off(
'workspace-datasource-schema-table-removed',
onDataSourceSchemaTableRemoved
7 changes: 7 additions & 0 deletions dev/postgresql/Dockerfile
@@ -0,0 +1,7 @@
FROM postgres:16

# Install pgvector
RUN apt-get update && apt-get install -y postgresql-16-pgvector

# Clean up
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
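
Installing postgresql-16-pgvector only ships the extension files; each database still has to enable the extension before the vector type and the <=> operator are available. How (or whether) this repository does that is not shown in this diff; a one-off statement such as the following would be enough (hypothetical, e.g. run from a migration or an init script):

import prisma from '@briefer/database'

// Hypothetical, not part of this PR: enable pgvector in the target database.
await prisma().$executeRawUnsafe('CREATE EXTENSION IF NOT EXISTS vector')
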
4 changes: 3 additions & 1 deletion docker-compose.dev.yaml
@@ -1,6 +1,8 @@
services:
postgres:
- image: postgres
build:
context: dev/postgresql
dockerfile: Dockerfile
ports:
- '5432:5432'
environment:
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -1,6 +1,6 @@
services:
postgres:
- image: postgres
image: pgvector/pgvector:pg16
environment:
POSTGRES_DB: 'briefer'
POSTGRES_USER: ${POSTGRES_USERNAME:?error}