diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 4af4d79e..8fe15254 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -102,7 +102,8 @@ export const verifiedAnswerConfig = { }, }; export const retrievalConfig = { - model: OPENAI_EMBEDDING_DEPLOYMENT, + preprocessorLlm: OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, + embeddingModel: OPENAI_EMBEDDING_DEPLOYMENT, findNearestNeighborsOptions: { k: 5, path: "embedding", @@ -113,7 +114,7 @@ export const retrievalConfig = { export const embedder = makeOpenAiEmbedder({ openAiClient, - deployment: retrievalConfig.model, + deployment: retrievalConfig.embeddingModel, backoffOptions: { numOfAttempts: 3, maxDelay: 5000, @@ -157,21 +158,29 @@ export const preprocessorOpenAiClient = wrapOpenAI( }) ); -export const generateUserPrompt = makeVerifiedAnswerGenerateUserPrompt({ - findVerifiedAnswer, - onVerifiedAnswerFound: (verifiedAnswer) => { - return { - ...verifiedAnswer, - references: verifiedAnswer.references.map(addReferenceSourceType), - }; - }, - onNoVerifiedAnswerFound: makeStepBackRagGenerateUserPrompt({ - openAiClient: preprocessorOpenAiClient, - model: OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, - findContent, - numPrecedingMessagesToInclude: 6, +export const generateUserPrompt = wrapTraced( + makeVerifiedAnswerGenerateUserPrompt({ + findVerifiedAnswer, + onVerifiedAnswerFound: (verifiedAnswer) => { + return { + ...verifiedAnswer, + references: verifiedAnswer.references.map(addReferenceSourceType), + }; + }, + onNoVerifiedAnswerFound: wrapTraced( + makeStepBackRagGenerateUserPrompt({ + openAiClient: preprocessorOpenAiClient, + model: retrievalConfig.preprocessorLlm, + findContent, + numPrecedingMessagesToInclude: 6, + }), + { name: "makeStepBackRagGenerateUserPrompt" } + ), }), -}); + { + name: "generateUserPrompt", + } +); export const mongodb = new MongoClient(MONGODB_CONNECTION_URI); diff --git a/packages/chatbot-server-mongodb-public/src/eval/fuzzyLinkMatch.ts b/packages/chatbot-server-mongodb-public/src/eval/fuzzyLinkMatch.ts index faa5172c..df78aafb 100644 --- a/packages/chatbot-server-mongodb-public/src/eval/fuzzyLinkMatch.ts +++ b/packages/chatbot-server-mongodb-public/src/eval/fuzzyLinkMatch.ts @@ -22,11 +22,13 @@ function cleanPath(path: string) { } function getCleanPath(maybeUrl: string) { + let out = ""; try { const url = new URL(maybeUrl); - return cleanPath(url.pathname); + out = cleanPath(url.pathname); } catch (error) { // If it's not a valid URL, return the input string as is - return maybeUrl; + out = cleanPath(maybeUrl); } + return out; } diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts index 87dbf1ab..f0cff7ed 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts @@ -5,16 +5,16 @@ import { GenerateUserPromptFuncReturnValue, Message, UserMessage, - updateFrontMatter, } from "mongodb-chatbot-server"; import { OpenAI } from "mongodb-rag-core/openai"; -import { makeStepBackUserQuery } from "./makeStepBackUserQuery"; import { stripIndents } from "common-tags"; import { strict as assert } from "assert"; import { logRequest } from "../utils"; import { makeMongoDbReferences } from "./makeMongoDbReferences"; import { extractMongoDbMetadataFromUserMessage } from "./extractMongoDbMetadataFromUserMessage"; import { userMessageMongoDbGuardrail } from "./userMessageMongoDbGuardrail"; +import { retrieveRelevantContent } from "./retrieveRelevantContent"; + interface MakeStepBackGenerateUserPromptProps { openAiClient: OpenAI; model: string; @@ -108,20 +108,16 @@ export const makeStepBackRagGenerateUserPrompt = ({ metadataForQuery.mongoDbProductName = metadata.mongoDbProduct; } - const { transformedUserQuery } = await makeStepBackUserQuery({ - openAiClient, - model, - messages: precedingMessagesToInclude, - userMessageText: updateFrontMatter(userMessageText, metadataForQuery), - }); - logRequest({ - reqId, - message: `Step back query: ${transformedUserQuery}`, - }); + const { transformedUserQuery, content, queryEmbedding, searchQuery } = + await retrieveRelevantContent({ + findContent, + metadataForQuery, + model, + openAiClient, + precedingMessagesToInclude, + userMessageText, + }); - const { content, queryEmbedding } = await findContent({ - query: updateFrontMatter(transformedUserQuery, metadataForQuery), - }); logRequest({ reqId, message: `Found ${content.length} results for query: ${content @@ -137,8 +133,12 @@ export const makeStepBackRagGenerateUserPrompt = ({ url: c.url, score: c.score, })), - customData, - preprocessedContent: transformedUserQuery, + customData: { + ...customData, + ...metadata, + searchQuery, + transformedUserQuery, + }, } satisfies UserMessage; if (content.length === 0) { return { diff --git a/packages/chatbot-server-mongodb-public/src/retrieval.eval.ts b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts similarity index 69% rename from packages/chatbot-server-mongodb-public/src/retrieval.eval.ts rename to packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts index a7e9f1d6..eee0f1e3 100644 --- a/packages/chatbot-server-mongodb-public/src/retrieval.eval.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts @@ -1,18 +1,26 @@ +import "dotenv/config"; import { Eval, EvalCase, EvalScorer, EvalTask } from "braintrust"; -import { MongoDbTag } from "./mongoDbMetadata"; import fs from "fs"; import path from "path"; import { strict as assert } from "assert"; -import { averagePrecisionAtK } from "./eval/scorers/averagePrecisionAtK"; -import { getConversationsEvalCasesFromYaml } from "./eval/getConversationEvalCasesFromYaml"; -import { ExtractMongoDbMetadataFunction } from "./processors/extractMongoDbMetadataFromUserMessage"; -import { findContent, retrievalConfig } from "./config"; -import { fuzzyLinkMatch } from "./eval/fuzzyLinkMatch"; -import { binaryNdcgAtK } from "./eval/scorers/binaryNdcgAtK"; -import { f1AtK } from "./eval/scorers/f1AtK"; -import { precisionAtK } from "./eval/scorers/precisionAtK"; -import { recallAtK } from "./eval/scorers/recallAtK"; -import "dotenv/config"; +import { + retrievalConfig, + findContent, + preprocessorOpenAiClient, +} from "../config"; +import { fuzzyLinkMatch } from "../eval/fuzzyLinkMatch"; +import { getConversationsEvalCasesFromYaml } from "../eval/getConversationEvalCasesFromYaml"; +import { averagePrecisionAtK } from "../eval/scorers/averagePrecisionAtK"; +import { binaryNdcgAtK } from "../eval/scorers/binaryNdcgAtK"; +import { f1AtK } from "../eval/scorers/f1AtK"; +import { precisionAtK } from "../eval/scorers/precisionAtK"; +import { recallAtK } from "../eval/scorers/recallAtK"; +import { MongoDbTag } from "../mongoDbMetadata"; +import { + extractMongoDbMetadataFromUserMessage, + ExtractMongoDbMetadataFunction, +} from "./extractMongoDbMetadataFromUserMessage"; +import { retrieveRelevantContent } from "./retrieveRelevantContent"; interface RetrievalEvalCaseInput { query: string; @@ -58,18 +66,33 @@ const simpleConversationEvalTask: EvalTask< RetrievalEvalCaseInput, RetrievalTaskOutput > = async function (data) { - const results = await findContent({ query: data.query }); + const metadataForQuery = await extractMongoDbMetadataFromUserMessage({ + openAiClient: preprocessorOpenAiClient, + model: retrievalConfig.preprocessorLlm, + userMessageText: data.query, + }); + const results = await retrieveRelevantContent({ + userMessageText: data.query, + model: retrievalConfig.preprocessorLlm, + openAiClient: preprocessorOpenAiClient, + findContent, + metadataForQuery, + }); + return { results: results.content.map((c) => ({ url: c.url, content: c.text, score: c.score, })), + extractedMetadata: metadataForQuery, + rewrittenQuery: results.transformedUserQuery, + searchString: results.searchQuery, }; }; async function getConversationRetrievalEvalData() { - const basePath = path.resolve(__dirname, "..", "evalCases"); + const basePath = path.resolve(__dirname, "..", "..", "evalCases"); const includedLinksConversations = getConversationsEvalCasesFromYaml( fs.readFileSync( path.resolve(basePath, "included_links_conversations.yml"), @@ -163,8 +186,17 @@ const RetrievedLengthOverK: RetrievalEvalScorer = async (args) => { }; }; +const AvgSearchScore: RetrievalEvalScorer = async (args) => { + return { + name: "AvgSearchScore", + score: + args.output.results.reduce((acc, r) => acc + r.score, 0) / + args.output.results.length, + }; +}; + Eval("mongodb-chatbot-retrieval", { - experimentName: `mongodb-chatbot-retrieval-latest?model=${retrievalConfig.model}&@K=${k}&minScore=${retrievalConfig.findNearestNeighborsOptions.minScore}`, + experimentName: `mongodb-chatbot-retrieval-latest?model=${retrievalConfig.embeddingModel}&@K=${k}&minScore=${retrievalConfig.findNearestNeighborsOptions.minScore}`, metadata: { description: "Evaluates quality of chatbot retrieval system", retrievalConfig, @@ -175,6 +207,7 @@ Eval("mongodb-chatbot-retrieval", { scores: [ BinaryNdcgAtK, F1AtK, + AvgSearchScore, RetrievedLengthOverK, AveragePrecisionAtK, PrecisionAtK, diff --git a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts new file mode 100644 index 00000000..8cc844bd --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts @@ -0,0 +1,84 @@ +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { retrieveRelevantContent } from "./retrieveRelevantContent"; +import { makeMockOpenAIToolCall } from "../test/mockOpenAi"; +import { StepBackUserQueryMongoDbFunction } from "./makeStepBackUserQuery"; +import { OpenAI } from "mongodb-rag-core/openai"; + +jest.mock("mongodb-rag-core/openai", () => + makeMockOpenAIToolCall({ transformedUserQuery: "transformedUserQuery" }) +); +describe("retrieveRelevantContent", () => { + const model = "model"; + const funcRes = { + transformedUserQuery: "transformedUserQuery", + } satisfies StepBackUserQueryMongoDbFunction; + const fakeEmbedding = [1, 2, 3]; + + const fakeContentBase = { + embedding: fakeEmbedding, + score: 1, + url: "url", + tokenCount: 3, + sourceName: "sourceName", + updated: new Date(), + }; + const fakeFindContent: FindContentFunc = async ({ query }) => { + return { + content: [ + { + text: "all about " + query, + ...fakeContentBase, + }, + ], + queryEmbedding: fakeEmbedding, + }; + }; + + const mockToolCallOpenAi = new OpenAI({ + apiKey: "apiKey", + }); + const argsBase = { + openAiClient: mockToolCallOpenAi, + model, + userMessageText: "something", + findContent: fakeFindContent, + }; + const metadataForQuery = { + programmingLanguage: "javascript", + mongoDbProduct: "Aggregation Framework", + }; + it("should return content, queryEmbedding, transformedUserQuery, searchQuery with metadata", async () => { + const res = await retrieveRelevantContent({ + ...argsBase, + metadataForQuery, + }); + expect(res).toEqual({ + content: [ + { + text: expect.any(String), + ...fakeContentBase, + }, + ], + queryEmbedding: fakeEmbedding, + transformedUserQuery: funcRes.transformedUserQuery, + searchQuery: updateFrontMatter( + funcRes.transformedUserQuery, + metadataForQuery + ), + }); + }); + it("should return content, queryEmbedding, transformedUserQuery, searchQuery without", async () => { + const res = await retrieveRelevantContent(argsBase); + expect(res).toEqual({ + content: [ + { + text: expect.any(String), + ...fakeContentBase, + }, + ], + queryEmbedding: fakeEmbedding, + transformedUserQuery: funcRes.transformedUserQuery, + searchQuery: funcRes.transformedUserQuery, + }); + }); +}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts new file mode 100644 index 00000000..a261d527 --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts @@ -0,0 +1,39 @@ +import { makeStepBackUserQuery } from "./makeStepBackUserQuery"; +import { FindContentFunc, Message } from "mongodb-rag-core"; +import { updateFrontMatter } from "mongodb-rag-core"; +import { OpenAI } from "mongodb-rag-core/openai"; + +export const retrieveRelevantContent = async function ({ + openAiClient, + model, + precedingMessagesToInclude, + userMessageText, + metadataForQuery, + findContent, +}: { + openAiClient: OpenAI; + model: string; + precedingMessagesToInclude?: Message[]; + userMessageText: string; + metadataForQuery?: Record; + findContent: FindContentFunc; +}) { + const { transformedUserQuery } = await makeStepBackUserQuery({ + openAiClient, + model, + messages: precedingMessagesToInclude, + userMessageText: metadataForQuery + ? updateFrontMatter(userMessageText, metadataForQuery) + : userMessageText, + }); + + const searchQuery = metadataForQuery + ? updateFrontMatter(transformedUserQuery, metadataForQuery) + : transformedUserQuery; + + const { content, queryEmbedding } = await findContent({ + query: searchQuery, + }); + + return { content, queryEmbedding, transformedUserQuery, searchQuery }; +};