Skip to content

Commit

Permalink
(EAI-602): Refactor preprocessing to support retrieval evals (#552)
Browse files Browse the repository at this point in the history
* add distinct retrieveRelevantContent module + braintrust tracing

* move retrieval config

* retrieval eval working in correct location

* Add avg score

* remove unused imports
  • Loading branch information
mongodben authored Nov 11, 2024
1 parent 201aa43 commit 8eb9ec9
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 49 deletions.
41 changes: 25 additions & 16 deletions packages/chatbot-server-mongodb-public/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ export const verifiedAnswerConfig = {
},
};
export const retrievalConfig = {
model: OPENAI_EMBEDDING_DEPLOYMENT,
preprocessorLlm: OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT,
embeddingModel: OPENAI_EMBEDDING_DEPLOYMENT,
findNearestNeighborsOptions: {
k: 5,
path: "embedding",
Expand All @@ -113,7 +114,7 @@ export const retrievalConfig = {

export const embedder = makeOpenAiEmbedder({
openAiClient,
deployment: retrievalConfig.model,
deployment: retrievalConfig.embeddingModel,
backoffOptions: {
numOfAttempts: 3,
maxDelay: 5000,
Expand Down Expand Up @@ -157,21 +158,29 @@ export const preprocessorOpenAiClient = wrapOpenAI(
})
);

export const generateUserPrompt = makeVerifiedAnswerGenerateUserPrompt({
findVerifiedAnswer,
onVerifiedAnswerFound: (verifiedAnswer) => {
return {
...verifiedAnswer,
references: verifiedAnswer.references.map(addReferenceSourceType),
};
},
onNoVerifiedAnswerFound: makeStepBackRagGenerateUserPrompt({
openAiClient: preprocessorOpenAiClient,
model: OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT,
findContent,
numPrecedingMessagesToInclude: 6,
export const generateUserPrompt = wrapTraced(
makeVerifiedAnswerGenerateUserPrompt({
findVerifiedAnswer,
onVerifiedAnswerFound: (verifiedAnswer) => {
return {
...verifiedAnswer,
references: verifiedAnswer.references.map(addReferenceSourceType),
};
},
onNoVerifiedAnswerFound: wrapTraced(
makeStepBackRagGenerateUserPrompt({
openAiClient: preprocessorOpenAiClient,
model: retrievalConfig.preprocessorLlm,
findContent,
numPrecedingMessagesToInclude: 6,
}),
{ name: "makeStepBackRagGenerateUserPrompt" }
),
}),
});
{
name: "generateUserPrompt",
}
);

export const mongodb = new MongoClient(MONGODB_CONNECTION_URI);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ function cleanPath(path: string) {
}

/**
 * Normalizes a URL-or-path string to a cleaned path.
 *
 * If the input parses as an absolute URL, only its pathname is cleaned;
 * otherwise the whole input is treated as a path and cleaned directly.
 */
function getCleanPath(maybeUrl: string) {
  let out = "";
  try {
    const url = new URL(maybeUrl);
    out = cleanPath(url.pathname);
  } catch (error) {
    // Not a parseable absolute URL — clean the raw input as a path.
    out = cleanPath(maybeUrl);
  }
  return out;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import {
GenerateUserPromptFuncReturnValue,
Message,
UserMessage,
updateFrontMatter,
} from "mongodb-chatbot-server";
import { OpenAI } from "mongodb-rag-core/openai";
import { makeStepBackUserQuery } from "./makeStepBackUserQuery";
import { stripIndents } from "common-tags";
import { strict as assert } from "assert";
import { logRequest } from "../utils";
import { makeMongoDbReferences } from "./makeMongoDbReferences";
import { extractMongoDbMetadataFromUserMessage } from "./extractMongoDbMetadataFromUserMessage";
import { userMessageMongoDbGuardrail } from "./userMessageMongoDbGuardrail";
import { retrieveRelevantContent } from "./retrieveRelevantContent";

interface MakeStepBackGenerateUserPromptProps {
openAiClient: OpenAI;
model: string;
Expand Down Expand Up @@ -108,20 +108,16 @@ export const makeStepBackRagGenerateUserPrompt = ({
metadataForQuery.mongoDbProductName = metadata.mongoDbProduct;
}

const { transformedUserQuery } = await makeStepBackUserQuery({
openAiClient,
model,
messages: precedingMessagesToInclude,
userMessageText: updateFrontMatter(userMessageText, metadataForQuery),
});
logRequest({
reqId,
message: `Step back query: ${transformedUserQuery}`,
});
const { transformedUserQuery, content, queryEmbedding, searchQuery } =
await retrieveRelevantContent({
findContent,
metadataForQuery,
model,
openAiClient,
precedingMessagesToInclude,
userMessageText,
});

const { content, queryEmbedding } = await findContent({
query: updateFrontMatter(transformedUserQuery, metadataForQuery),
});
logRequest({
reqId,
message: `Found ${content.length} results for query: ${content
Expand All @@ -137,8 +133,12 @@ export const makeStepBackRagGenerateUserPrompt = ({
url: c.url,
score: c.score,
})),
customData,
preprocessedContent: transformedUserQuery,
customData: {
...customData,
...metadata,
searchQuery,
transformedUserQuery,
},
} satisfies UserMessage;
if (content.length === 0) {
return {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import "dotenv/config";
import { Eval, EvalCase, EvalScorer, EvalTask } from "braintrust";
import { MongoDbTag } from "./mongoDbMetadata";
import fs from "fs";
import path from "path";
import { strict as assert } from "assert";
import { averagePrecisionAtK } from "./eval/scorers/averagePrecisionAtK";
import { getConversationsEvalCasesFromYaml } from "./eval/getConversationEvalCasesFromYaml";
import { ExtractMongoDbMetadataFunction } from "./processors/extractMongoDbMetadataFromUserMessage";
import { findContent, retrievalConfig } from "./config";
import { fuzzyLinkMatch } from "./eval/fuzzyLinkMatch";
import { binaryNdcgAtK } from "./eval/scorers/binaryNdcgAtK";
import { f1AtK } from "./eval/scorers/f1AtK";
import { precisionAtK } from "./eval/scorers/precisionAtK";
import { recallAtK } from "./eval/scorers/recallAtK";
import "dotenv/config";
import {
retrievalConfig,
findContent,
preprocessorOpenAiClient,
} from "../config";
import { fuzzyLinkMatch } from "../eval/fuzzyLinkMatch";
import { getConversationsEvalCasesFromYaml } from "../eval/getConversationEvalCasesFromYaml";
import { averagePrecisionAtK } from "../eval/scorers/averagePrecisionAtK";
import { binaryNdcgAtK } from "../eval/scorers/binaryNdcgAtK";
import { f1AtK } from "../eval/scorers/f1AtK";
import { precisionAtK } from "../eval/scorers/precisionAtK";
import { recallAtK } from "../eval/scorers/recallAtK";
import { MongoDbTag } from "../mongoDbMetadata";
import {
extractMongoDbMetadataFromUserMessage,
ExtractMongoDbMetadataFunction,
} from "./extractMongoDbMetadataFromUserMessage";
import { retrieveRelevantContent } from "./retrieveRelevantContent";

interface RetrievalEvalCaseInput {
query: string;
Expand Down Expand Up @@ -58,18 +66,33 @@ const simpleConversationEvalTask: EvalTask<
RetrievalEvalCaseInput,
RetrievalTaskOutput
> = async function (data) {
const results = await findContent({ query: data.query });
const metadataForQuery = await extractMongoDbMetadataFromUserMessage({
openAiClient: preprocessorOpenAiClient,
model: retrievalConfig.preprocessorLlm,
userMessageText: data.query,
});
const results = await retrieveRelevantContent({
userMessageText: data.query,
model: retrievalConfig.preprocessorLlm,
openAiClient: preprocessorOpenAiClient,
findContent,
metadataForQuery,
});

return {
results: results.content.map((c) => ({
url: c.url,
content: c.text,
score: c.score,
})),
extractedMetadata: metadataForQuery,
rewrittenQuery: results.transformedUserQuery,
searchString: results.searchQuery,
};
};

async function getConversationRetrievalEvalData() {
const basePath = path.resolve(__dirname, "..", "evalCases");
const basePath = path.resolve(__dirname, "..", "..", "evalCases");
const includedLinksConversations = getConversationsEvalCasesFromYaml(
fs.readFileSync(
path.resolve(basePath, "included_links_conversations.yml"),
Expand Down Expand Up @@ -163,8 +186,17 @@ const RetrievedLengthOverK: RetrievalEvalScorer = async (args) => {
};
};

/**
 * Scores the mean vector-search score across all retrieved results.
 * Returns 0 when no results were retrieved — the original 0/0 division
 * would otherwise yield NaN and corrupt aggregate experiment metrics.
 */
const AvgSearchScore: RetrievalEvalScorer = async (args) => {
  const { results } = args.output;
  return {
    name: "AvgSearchScore",
    score:
      results.length === 0
        ? 0
        : results.reduce((acc, r) => acc + r.score, 0) / results.length,
  };
};

Eval("mongodb-chatbot-retrieval", {
experimentName: `mongodb-chatbot-retrieval-latest?model=${retrievalConfig.model}&@K=${k}&minScore=${retrievalConfig.findNearestNeighborsOptions.minScore}`,
experimentName: `mongodb-chatbot-retrieval-latest?model=${retrievalConfig.embeddingModel}&@K=${k}&minScore=${retrievalConfig.findNearestNeighborsOptions.minScore}`,
metadata: {
description: "Evaluates quality of chatbot retrieval system",
retrievalConfig,
Expand All @@ -175,6 +207,7 @@ Eval("mongodb-chatbot-retrieval", {
scores: [
BinaryNdcgAtK,
F1AtK,
AvgSearchScore,
RetrievedLengthOverK,
AveragePrecisionAtK,
PrecisionAtK,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core";
import { retrieveRelevantContent } from "./retrieveRelevantContent";
import { makeMockOpenAIToolCall } from "../test/mockOpenAi";
import { StepBackUserQueryMongoDbFunction } from "./makeStepBackUserQuery";
import { OpenAI } from "mongodb-rag-core/openai";

// Mock the OpenAI module so the step-back "LLM call" deterministically returns
// a tool call whose payload is the fixed transformedUserQuery below.
// NOTE: jest.mock is hoisted above the imports at runtime.
jest.mock("mongodb-rag-core/openai", () =>
  makeMockOpenAIToolCall({ transformedUserQuery: "transformedUserQuery" })
);
describe("retrieveRelevantContent", () => {
  const model = "model";
  // Expected payload of the mocked step-back tool call (must match jest.mock above).
  const funcRes = {
    transformedUserQuery: "transformedUserQuery",
  } satisfies StepBackUserQueryMongoDbFunction;
  const fakeEmbedding = [1, 2, 3];

  // Shared shape for the stubbed content records returned by the fake retriever.
  const fakeContentBase = {
    embedding: fakeEmbedding,
    score: 1,
    url: "url",
    tokenCount: 3,
    sourceName: "sourceName",
    updated: new Date(),
  };
  // Fake FindContentFunc: echoes the query into the content text so the tests
  // can observe which search string was actually used.
  const fakeFindContent: FindContentFunc = async ({ query }) => {
    return {
      content: [
        {
          text: "all about " + query,
          ...fakeContentBase,
        },
      ],
      queryEmbedding: fakeEmbedding,
    };
  };

  // Real client object; its network layer is replaced by the jest.mock above.
  const mockToolCallOpenAi = new OpenAI({
    apiKey: "apiKey",
  });
  const argsBase = {
    openAiClient: mockToolCallOpenAi,
    model,
    userMessageText: "something",
    findContent: fakeFindContent,
  };
  const metadataForQuery = {
    programmingLanguage: "javascript",
    mongoDbProduct: "Aggregation Framework",
  };
  it("should return content, queryEmbedding, transformedUserQuery, searchQuery with metadata", async () => {
    const res = await retrieveRelevantContent({
      ...argsBase,
      metadataForQuery,
    });
    expect(res).toEqual({
      content: [
        {
          text: expect.any(String),
          ...fakeContentBase,
        },
      ],
      queryEmbedding: fakeEmbedding,
      transformedUserQuery: funcRes.transformedUserQuery,
      // With metadata, the search query embeds it as front matter.
      searchQuery: updateFrontMatter(
        funcRes.transformedUserQuery,
        metadataForQuery
      ),
    });
  });
  it("should return content, queryEmbedding, transformedUserQuery, searchQuery without", async () => {
    const res = await retrieveRelevantContent(argsBase);
    expect(res).toEqual({
      content: [
        {
          text: expect.any(String),
          ...fakeContentBase,
        },
      ],
      queryEmbedding: fakeEmbedding,
      transformedUserQuery: funcRes.transformedUserQuery,
      // Without metadata, the search query is the transformed query as-is.
      searchQuery: funcRes.transformedUserQuery,
    });
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { makeStepBackUserQuery } from "./makeStepBackUserQuery";
import { FindContentFunc, Message } from "mongodb-rag-core";
import { updateFrontMatter } from "mongodb-rag-core";
import { OpenAI } from "mongodb-rag-core/openai";

/**
 * Retrieves content relevant to a user message.
 *
 * Rewrites the message into a step-back query via the LLM, optionally
 * prefixing both the LLM input and the final search string with metadata
 * front matter, then runs vector search with the provided findContent.
 *
 * @returns The retrieved content, its query embedding, the LLM-transformed
 *   user query, and the final search string sent to findContent.
 */
export const retrieveRelevantContent = async function ({
  openAiClient,
  model,
  precedingMessagesToInclude,
  userMessageText,
  metadataForQuery,
  findContent,
}: {
  openAiClient: OpenAI;
  model: string;
  precedingMessagesToInclude?: Message[];
  userMessageText: string;
  metadataForQuery?: Record<string, unknown>;
  findContent: FindContentFunc;
}) {
  // Prepend metadata as front matter only when metadata was supplied.
  const withMetadata = (text: string) =>
    metadataForQuery ? updateFrontMatter(text, metadataForQuery) : text;

  const { transformedUserQuery } = await makeStepBackUserQuery({
    openAiClient,
    model,
    messages: precedingMessagesToInclude,
    userMessageText: withMetadata(userMessageText),
  });

  const searchQuery = withMetadata(transformedUserQuery);

  const { content, queryEmbedding } = await findContent({ query: searchQuery });

  return { content, queryEmbedding, transformedUserQuery, searchQuery };
};

0 comments on commit 8eb9ec9

Please sign in to comment.