Skip to content

Commit

Permalink
Merge pull request #310 from briefercloud/rag2
Browse files Browse the repository at this point in the history
Rag SQL AI
  • Loading branch information
vieiralucas authored Jan 7, 2025
2 parents 01bddcc + a100065 commit f210ac4
Show file tree
Hide file tree
Showing 15 changed files with 308 additions and 16 deletions.
1 change: 1 addition & 0 deletions apps/api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"lru-cache": "^10.2.2",
"mssql": "^11.0.1",
"node-fetch": "^3.3.2",
"openai": "^4.77.3",
"p-all": "^5.0.0",
"p-queue": "^8.0.1",
"parse-duration": "^1.1.0",
Expand Down
45 changes: 44 additions & 1 deletion apps/api/src/datasources/structure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import prisma, {
DataSource,
DataSourceSchema as DBDataSourceSchema,
decrypt,
getWorkspaceById,
getWorkspaceWithSecrets,
} from '@briefer/database'
import { IOServer } from '../websocket/index.js'
import {
Expand Down Expand Up @@ -37,6 +39,7 @@ import { PythonExecutionError } from '../python/index.js'
import { getSqlServerSchema } from './sqlserver.js'
import { z } from 'zod'
import { splitEvery } from 'ramda'
import { createEmbedding } from '../embedding.js'

function decryptDBData(
dataSourceId: string,
Expand Down Expand Up @@ -519,19 +522,54 @@ async function _refreshDataSourceStructure(
socketServer: IOServer,
dataSource: APIDataSource
) {
const workspace = await getWorkspaceWithSecrets(
dataSource.config.data.workspaceId
)
if (!workspace) {
throw new Error(
`Failed to find Workspace(${dataSource.config.data.workspaceId}) for DataSource(${dataSource.config.data.id})`
)
}

const updateQueue = new PQueue({ concurrency: 1 })
const tables: { schema: string; table: string }[] = []
let defaultSchema = ''
const onTable: OnTable = (schema, tableName, table, defaultSchema) => {
defaultSchema = defaultSchema || schema
tables.push({ schema, table: tableName })
updateQueue.add(async () => {
const openAiApiKey = workspace.secrets?.openAiApiKey
let embedding: number[] | null = null
if (openAiApiKey) {
let ddl = `CREATE TABLE ${schema}.${tableName} (\n`
for (const c of table.columns) {
ddl += ` ${c.name} ${c.type}\n`
}
try {
embedding = await createEmbedding(
ddl,
decrypt(openAiApiKey, config().WORKSPACE_SECRETS_ENCRYPTION_KEY)
)
} catch (err) {
logger().error(
{
err,
dataSourceId: dataSource.config.data.id,
dataSourceType: dataSource.config.type,
schemaName: schema,
tableName,
},
'Failed to create embedding'
)
}
}

await prisma().dataSourceSchema.update({
where: { id: dataSource.structure.id },
data: { defaultSchema },
})

await prisma().dataSourceSchemaTable.upsert({
const dbSchemaTable = await prisma().dataSourceSchemaTable.upsert({
where: {
dataSourceSchemaId_schema_name: {
dataSourceSchemaId: dataSource.structure.id,
Expand All @@ -550,6 +588,11 @@ async function _refreshDataSourceStructure(
},
})

if (embedding) {
await prisma()
.$queryRaw`UPDATE "DataSourceSchemaTable" SET embedding = ${embedding} WHERE id = ${dbSchemaTable.id}::uuid`
}

broadcastDataSource(socketServer, dataSource)
broadcastDataSourceSchemaTableUpdate(
socketServer,
Expand Down
55 changes: 55 additions & 0 deletions apps/api/src/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { createHash } from 'crypto'
import prisma from '@briefer/database'
import { OpenAI } from 'openai'
import { head } from 'ramda'
import { z } from 'zod'
import { logger } from './logger.js'
import { jsonString } from '@briefer/types'

export async function createEmbedding(input: string, openAiApiKey: string) {
const model = 'text-embedding-3-small'
const inputChecksum = createHash('sha256').update(input).digest('hex')
const rawExistingEmbeddings = await prisma()
.$queryRaw`SELECT embedding::text FROM "EmbeddingCache" WHERE "inputChecksum" = ${inputChecksum} AND model = ${model}`

const parsedExistingEmbeddings = z
.array(z.object({ embedding: jsonString.pipe(z.array(z.number())) }))
.safeParse(rawExistingEmbeddings)
if (parsedExistingEmbeddings.success) {
const embedding = head(parsedExistingEmbeddings.data)?.embedding
if (embedding) {
return embedding
}
} else {
logger().error(
{
err: parsedExistingEmbeddings.error,
},
'Failed to parse existing embeddings'
)
}

const openai = new OpenAI({ apiKey: openAiApiKey })
const embeddingResponse = await openai.embeddings.create({
model,
input,
})
const embedding = head(embeddingResponse.data)?.embedding
if (!embedding) {
throw new Error('OpenAI did not return any embeddings')
}

try {
await prisma()
.$queryRaw`INSERT INTO "EmbeddingCache" ("inputChecksum", model, embedding) VALUES (${inputChecksum}, ${model}, ${embedding}::vector) ON CONFLICT DO NOTHING`
} catch (err) {
logger().error(
{
err,
},
'Failed to insert embedding into cache'
)
}

return embedding
}
60 changes: 55 additions & 5 deletions apps/api/src/yjs/v2/executor/ai/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import {
updateSQLAISuggestions,
} from '@briefer/editor'
import * as Y from 'yjs'
import {
import prisma, {
listDataSources,
getWorkspaceWithSecrets,
DataSource,
decrypt,
} from '@briefer/database'
import { logger } from '../../../../logger.js'
import { sqlEditStreamed } from '../../../../ai-api.js'
Expand All @@ -24,7 +25,10 @@ import {
fetchDataSourceStructureFromCache,
listSchemaTables,
} from '../../../../datasources/structure.js'
import { DataSourceStructureStateV3 } from '@briefer/types'
import { DataSourceStructureStateV3, uuidSchema } from '@briefer/types'
import { createEmbedding } from '../../../../embedding.js'
import { config } from '../../../../config/index.js'
import { z } from 'zod'

async function editWithAI(
workspaceId: string,
Expand Down Expand Up @@ -70,7 +74,12 @@ async function editWithAI(
dataSource.data.id,
dataSource.type
)
const tableInfo = await tableInfoFromStructure(dataSource, structure)
const tableInfo = await retrieveTableInfoForQuestion(
dataSource,
structure,
instructions,
workspace?.secrets?.openAiApiKey ?? null
)

event(assistantModelId)

Expand Down Expand Up @@ -110,6 +119,47 @@ async function editWithAI(
)
}

/**
 * Builds the table-info context passed to the AI SQL assistant.
 *
 * When the workspace has an OpenAI key and the data source has a cached
 * structure, ranks tables by pgvector cosine distance (`<=>`) between each
 * table's stored embedding and the question's embedding, and returns only
 * the 30 most relevant tables, most relevant first. Otherwise falls back
 * to dumping the whole cached structure.
 *
 * @param datasource - the data source being queried
 * @param structure - cached structure state, or null if never refreshed
 * @param question - the user's natural-language instructions
 * @param openAiApiKey - encrypted OpenAI key, or null when not configured
 * @returns table/column listing for the prompt, or null when unavailable
 */
async function retrieveTableInfoForQuestion(
  datasource: DataSource,
  structure: DataSourceStructureStateV3 | null,
  question: string,
  openAiApiKey: string | null
): Promise<string | null> {
  if (!structure || !openAiApiKey) {
    // No embeddings possible: fall back to the full structure dump.
    return tableInfoFromStructure(datasource, structure)
  }

  const questionEmbedding = await createEmbedding(
    question,
    decrypt(openAiApiKey, config().WORKSPACE_SECRETS_ENCRYPTION_KEY)
  )

  // Rank the schema's tables by similarity to the question; keep the top 30.
  const raw = await prisma()
    .$queryRaw`SELECT t.id, t.embedding <=> ${questionEmbedding}::vector AS distance FROM "DataSourceSchemaTable" t INNER JOIN "DataSourceSchema" s ON s.id = t."dataSourceSchemaId" WHERE s.id = ${structure.id}::uuid ORDER BY distance LIMIT 30`

  const result = z.array(z.object({ id: uuidSchema })).parse(raw)
  if (result.length === 0) {
    // No tables in the schema yet — fall back rather than return nothing.
    return tableInfoFromStructure(datasource, structure)
  }

  const tables = await prisma().dataSourceSchemaTable.findMany({
    where: {
      id: { in: result.map((r) => r.id) },
    },
  })

  // BUG FIX: `findMany` with an `in` filter does not preserve the order of
  // the id list, which silently discarded the relevance ranking computed
  // above. Re-sort the rows back into ranked order before rendering.
  const tablesById = new Map(tables.map((t) => [t.id, t]))
  const rankedTables = result.flatMap((r) => {
    const table = tablesById.get(r.id)
    return table ? [table] : []
  })

  let tableInfo = ''
  for (const table of rankedTables) {
    tableInfo += `${table.schema}.${table.name}\n`
    const columns = z
      .array(z.object({ name: z.string(), type: z.string() }))
      .parse(table.columns)
    for (const column of columns) {
      tableInfo += `${column.name} ${column.type}\n`
    }
    tableInfo += '\n'
  }

  return tableInfo.trim()
}

async function tableInfoFromStructure(
config: DataSource,
structure: DataSourceStructureStateV3 | null
Expand All @@ -121,8 +171,8 @@ async function tableInfoFromStructure(
let result = ''
for await (const schemaTable of listSchemaTables([{ config, structure }])) {
result += `${schemaTable.schemaName}.${schemaTable.tableName}\n`
for (const columns of schemaTable.table.columns) {
result += `${columns.name} ${columns.type}\n`
for (const column of schemaTable.table.columns) {
result += `${column.name} ${column.type}\n`
}
result += '\n'
}
Expand Down
7 changes: 7 additions & 0 deletions dev/postgresql/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM postgres:16

# Install pgvector and clean the apt cache in the SAME layer: a separate
# "RUN apt-get clean" does not shrink the image, because the package lists
# are already committed in the previous layer.
RUN apt-get update && \
    apt-get install -y --no-install-recommends postgresql-16-pgvector && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
4 changes: 3 additions & 1 deletion docker-compose.dev.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
services:
postgres:
image: postgres
build:
context: dev/postgresql
dockerfile: Dockerfile
ports:
- '5432:5432'
environment:
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
postgres:
image: postgres
image: pgvector/pgvector:pg16
environment:
POSTGRES_DB: 'briefer'
POSTGRES_USER: ${POSTGRES_USERNAME:?error}
Expand Down
9 changes: 9 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ RUN apt-get update && apt-get install -y \
python3-venv \
python3-dev \
postgresql \
postgresql-common \
postgresql-contrib \
nginx \
sudo \
Expand All @@ -28,6 +29,14 @@ RUN apt-get update && apt-get install -y \
alien && \
apt-get clean && rm -rf /var/lib/apt/lists/*

# Set environment variable to skip the user prompt
ENV DEBIAN_FRONTEND=noninteractive
# Enable the PostgreSQL APT repository
RUN /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y
# Install the pgvector extension for PostgreSQL
RUN apt-get update && apt-get install -y postgresql-15-pgvector && \
apt-get clean && rm -rf /var/lib/apt/lists/*

#### ORACLE INSTANT CLIENT ####
ARG TARGETARCH

Expand Down
2 changes: 2 additions & 0 deletions docker/init_db.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ if [ ! -f /var/lib/postgresql/data/.init ]; then
psql -U postgres -c "GRANT ALL PRIVILEGES ON DATABASE briefer TO briefer;"
touch /var/lib/postgresql/data/.init
fi

psql -U postgres -d briefer -c "CREATE EXTENSION IF NOT EXISTS vector;"
6 changes: 4 additions & 2 deletions docker/setup/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def generate_jupyter_config():
}
with open(fpath, "w") as f:
json.dump(cfg, f, indent=4)
os.chown(fpath, pwd.getpwnam("jupyteruser").pw_uid, grp.getgrnam("jupyteruser").gr_gid)
os.chmod(fpath, 0o700)

# recursively chown jupyteruser home directory
os.system(f"chown -R jupyteruser:jupyteruser /home/jupyteruser")
os.system(f"chmod -R 700 /home/jupyteruser")

def setup_jupyter():
generate_jupyter_config()
Expand Down
2 changes: 1 addition & 1 deletion docker/supervisord.conf
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ directory=/app/setup/
user=jupyteruser
environment=HOME="/home/jupyteruser",USER="jupyteruser"
autostart=true
autorestart=false
autorestart=true
startsecs=0
priority=4
stdout_logfile=/dev/stdout
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- The "vector" column type below requires the pgvector extension.
CREATE EXTENSION IF NOT EXISTS vector;

-- AlterTable
-- Nullable embedding of the table's DDL; NULL until computed for a workspace
-- with an OpenAI key.
-- NOTE(review): "vector" is declared without a dimension, so the column can
-- hold embeddings of differing lengths; distance operators require matching
-- dimensions at query time — confirm all writers use the same model.
ALTER TABLE "DataSourceSchemaTable" ADD COLUMN "embedding" vector;
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- CreateTable
-- Cache of computed embeddings keyed by (sha256 of the input text, model)
-- so identical inputs never trigger a second embedding API call.
CREATE TABLE "EmbeddingCache" (
"id" UUID NOT NULL DEFAULT gen_random_uuid(),
"inputChecksum" TEXT NOT NULL,
"model" TEXT NOT NULL,
"embedding" vector NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,

CONSTRAINT "EmbeddingCache_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
-- The unique pair doubles as the lookup index for cache reads and backs the
-- ON CONFLICT DO NOTHING insert path used by the writer.
CREATE UNIQUE INDEX "unique_inputChecksum_model" ON "EmbeddingCache"("inputChecksum", "model");
13 changes: 13 additions & 0 deletions packages/database/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ model DataSourceSchemaTable {
schema String
columns Json
dataSourceSchemaId String @db.Uuid
embedding Unsupported("vector")?
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
Expand Down Expand Up @@ -671,3 +672,15 @@ model OnboardingTutorial {
@@unique([workspaceId])
}

model EmbeddingCache {
id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
inputChecksum String
model String
embedding Unsupported("vector")
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@@unique([inputChecksum, model], map: "unique_inputChecksum_model")
}
Loading

0 comments on commit f210ac4

Please sign in to comment.