feat: Add knowledge base selection to OpenAI chat
n4ze3m committed Sep 6, 2024
1 parent 7bdefc7 commit 9d66ccc
Showing 5 changed files with 165 additions and 72 deletions.
85 changes: 41 additions & 44 deletions server/src/handlers/api/v1/openai/chat.handler.ts
@@ -1,5 +1,5 @@
import type { FastifyRequest, FastifyReply } from "fastify";
import type { OpenaiRequestType } from "./type"
import type { OpenaiRequestType } from "./type";
import { getModelInfo } from "../../../../utils/get-model-info";
import { embeddings } from "../../../../utils/embeddings";
import { Document } from "langchain/document";
@@ -8,49 +8,56 @@ import { DialoqbaseHybridRetrival } from "../../../../utils/hybrid";
import { DialoqbaseVectorStore } from "../../../../utils/store";
import { createChatModel } from "../bot/playground/chat.service";
import { createChain } from "../../../../chain";
import { openaiNonStreamResponse, openaiStreamResponse } from "./openai-response";
import {
openaiNonStreamResponse,
openaiStreamResponse,
} from "./openai-response";
import { groupOpenAiMessages } from "./other";
import { nextTick } from "../../../../utils/nextTick";


export const createChatCompletionHandler = async (
request: FastifyRequest<OpenaiRequestType>,
reply: FastifyReply
) => {
try {
const {
model,
messages
} = request.body;
const { model, messages } = request.body;

const prisma = request.server.prisma;

let knowledge_base_ids: string[] = [];

const kb = request.body?.tools?.find(
(e) => e.type === "knowledge_base" && e.value.length > 0
);
if (kb) {
knowledge_base_ids = kb.value;
}
console.log(knowledge_base_ids)
const bot = await prisma.bot.findFirst({
where: {
OR: [
{
id: model
id: model,
},
{
publicId: model
}
publicId: model,
},
],
user_id: request.user.is_admin ? undefined : request.user.user_id,
},
})
});

if (!bot) {
return reply.status(404).send({
error: {
message: "Bot not found",
type: "not_found",
param: "model",
code: "bot_not_found"
}
code: "bot_not_found",
},
});
}


const embeddingInfo = await getModelInfo({
prisma,
model: bot.embedding,
@@ -63,12 +70,11 @@ export const createChatCompletionHandler = async (
message: "Embedding not found",
type: "not_found",
param: "embedding",
code: "embedding_not_found"
}
code: "embedding_not_found",
},
});
}


const embeddingModel = embeddings(
embeddingInfo.model_provider!.toLowerCase(),
embeddingInfo.model_id,
@@ -87,8 +93,8 @@ export const createChatCompletionHandler = async (
message: "Model not found",
type: "not_found",
param: "model",
code: "model_not_found"
}
code: "model_not_found",
},
});
}

@@ -100,6 +106,7 @@ export const createChatCompletionHandler = async (
retriever = new DialoqbaseHybridRetrival(embeddingModel, {
botId: bot.id,
sourceId: null,
knowledge_base_ids,
callbacks: [
{
handleRetrieverEnd(documents) {
@@ -114,11 +121,12 @@
{
botId: bot.id,
sourceId: null,
knowledge_base_ids,

}
);

retriever = vectorstore.asRetriever({
});
retriever = vectorstore.asRetriever({});
}

const streamedModel = createChatModel(
@@ -140,48 +148,37 @@
if (!request.body.stream) {
const res = await chain.invoke({
question: messages[messages.length - 1].content,
chat_history: groupOpenAiMessages(
messages
),
})

chat_history: groupOpenAiMessages(messages),
});

return reply.status(200).send(openaiNonStreamResponse(
res,
bot.name
))
return reply.status(200).send(openaiNonStreamResponse(res, bot.name));
}

const stream = await chain.stream({
question: messages[messages.length - 1].content,
chat_history: groupOpenAiMessages(
messages
),
})
chat_history: groupOpenAiMessages(messages),
});
reply.raw.setHeader("Content-Type", "text/event-stream");

for await (const token of stream) {
reply.sse({
data: openaiStreamResponse(
token || "",
bot.name
)
data: openaiStreamResponse(token || "", bot.name),
});
}
reply.sse({
data: "[DONE]\n\n"
})
data: "[DONE]\n\n",
});
await nextTick();
return reply.raw.end();
} catch (error) {
console.log(error)
console.log(error);
return reply.status(500).send({
error: {
message: error.message,
type: "internal_server_error",
param: null,
code: "internal_server_error"
}
code: "internal_server_error",
},
});
}
}
};
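
For context, a caller opts into the new knowledge-base filter through the `tools` field of the request body. The sketch below is illustrative only: the host, route path, authorization header, and all IDs are placeholders, not values taken from this commit.

// Hypothetical client call; endpoint path, API key, and IDs are assumptions for illustration.
const response = await fetch("https://your-dialoqbase-host/api/v1/openai/chat/completions", {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
    Authorization: "Bearer <api-key>",
  },
  body: JSON.stringify({
    model: "<bot-id-or-public-id>", // matched against bot.id or bot.publicId in the handler
    stream: false,
    messages: [{ role: "user", content: "What does the refund policy say?" }],
    // New in this commit: restrict retrieval to the listed knowledge bases (source IDs).
    tools: [{ type: "knowledge_base", value: ["<source-id-1>", "<source-id-2>"] }],
  }),
});
const completion = await response.json();

When `stream` is true, the handler instead responds with server-sent events and closes the stream with a `[DONE]` message, as shown in the diff above.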
4 changes: 4 additions & 0 deletions server/src/handlers/api/v1/openai/type.ts
@@ -7,5 +7,9 @@ export interface OpenaiRequestType {
model: string;
stream: boolean;
temperature: number;
tools?: {
type?: "knowledge_base",
value?: string[]
}[]
}
}
18 changes: 18 additions & 0 deletions server/src/schema/api/v1/openai/index.ts
@@ -30,6 +30,24 @@ export const createChatCompletionSchema: FastifySchema = {
},
temperature: {
type: "number"
},
tools: {
type: "array",
items: {
type: "object",
required: ["type"],
properties: {
type: {
type: "string"
},
value: {
type: "array",
items: {
type: "string"
}
}
}
}
}
}
}
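
On the validation side, each entry in `tools` only needs a `type`; `value` is optional and `type` is not restricted to "knowledge_base", so unrecognized tool entries pass validation and are simply ignored by the handler. A few hypothetical payload fragments, purely for illustration:

// Accepted by the schema; the handler then picks the first "knowledge_base" entry with a non-empty value.
const accepted = [{ type: "knowledge_base", value: ["<source-id>"] }];

// Also accepted, but ignored downstream because the type is not "knowledge_base".
const acceptedButIgnored = [{ type: "web_search" }];

// Rejected by Fastify's schema validation before the handler runs: "type" is required on every entry.
const rejected = [{ value: ["<source-id>"] }];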
65 changes: 50 additions & 15 deletions server/src/utils/hybrid.ts
@@ -1,5 +1,5 @@
import { Document } from "langchain/document";
import { PrismaClient } from "@prisma/client";
import { Prisma, PrismaClient } from "@prisma/client";
import { Embeddings } from "langchain/embeddings/base";
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import { CallbackManagerForRetrieverRun, Callbacks } from "langchain/callbacks";
@@ -8,6 +8,7 @@ const prisma = new PrismaClient();
export interface DialoqbaseLibArgs extends BaseRetrieverInput {
botId: string;
sourceId: string | null;
knowledge_base_ids?: string[];
}

interface SearchEmbeddingsResponse {
@@ -30,14 +31,36 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
embeddings: Embeddings;
similarityK = 5;
keywordK = 4;
knowledge_base_ids: string[];

constructor(embeddings: Embeddings, args: DialoqbaseLibArgs) {
super(args);
this.botId = args.botId;
this.sourceId = args.sourceId;
this.embeddings = embeddings;
this.knowledge_base_ids = args.knowledge_base_ids || [];
}
async similaritySearchWithSelectedKBs(
query: number[],
k: number,
knowledgeBaseIds: string[]
) {
const vector = `[${query?.join(",")}]`;
const results = await prisma.$queryRaw`
SELECT "sourceId", "content", "metadata",
(embedding <=> ${vector}::vector) AS distance
FROM "BotDocument"
WHERE "sourceId" IN (${Prisma.join(knowledgeBaseIds)})
ORDER BY distance ASC
LIMIT ${k}
`
return results as {
sourceId: string;
content: string;
metadata: object;
distance: number;
}[];
}

protected async similaritySearch(
query: string,
k: number,
@@ -53,20 +76,32 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
id: bot_id,
},
});
const data = await prisma.$queryRaw`
SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int)
`;
let result: (number | Document<object>)[][];
const match_count = botInfo?.noOfDocumentsToRetrieve || k;

if (this.knowledge_base_ids && this.knowledge_base_ids.length > 0) {
const data = await this.similaritySearchWithSelectedKBs(embeddedQuery, match_count, this.knowledge_base_ids);
result = data.map((resp) => [
new Document({
metadata: resp.metadata,
pageContent: resp.content,
}),
1 - resp.distance,
]);
} else {
const data = await prisma.$queryRaw`
SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int)
`;
result = (data as SearchEmbeddingsResponse[]).map((resp) => [
new Document({
metadata: resp.metadata,
pageContent: resp.content,
}),
resp.similarity,
]);
}


const result: [Document, number, number][] = (
data as SearchEmbeddingsResponse[]
).map((resp) => [
new Document({
metadata: resp.metadata,
pageContent: resp.content,
}),
resp.similarity * 10,
resp.id,
]);
let internetSearchResults = [];
if (botInfo.internetSearchEnabled) {
internetSearchResults = await searchInternet(this.embeddings, {
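Putting the handler and retriever changes together, the new parameter is used roughly as sketched below. The IDs and query are placeholders, `embeddingModel` stands for the result of `embeddings(...)` as in the handler, and `getRelevantDocuments` is the standard LangChain retriever entry point rather than something introduced here.

// Illustrative sketch, not code from this commit.
const retriever = new DialoqbaseHybridRetrival(embeddingModel, {
  botId: bot.id,
  sourceId: null,
  // When this list is non-empty, similaritySearch() goes through similaritySearchWithSelectedKBs(),
  // which filters "BotDocument" rows to the selected sources via Prisma.join (a parameterized IN clause)
  // and orders them by pgvector cosine distance (the <=> operator).
  knowledge_base_ids: ["<source-id-1>", "<source-id-2>"],
});

// The retriever converts distance back to a similarity-style score with `1 - distance`,
// matching the shape of the results produced by the similarity_search_v2 path that is
// still used when no knowledge bases are selected.
const docs = await retriever.getRelevantDocuments("What does the refund policy say?");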
(diff for the remaining changed file not loaded)
