From 24e4c2c1831fdfb54d3329af8a751c15eb4b1581 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Sun, 29 Sep 2024 19:19:29 -0700 Subject: [PATCH] feat(core): Allow more types of inputs in prompt templates (#6894) --- langchain-core/src/prompts/prompt.ts | 3 +- langchain-core/src/prompts/template.ts | 11 +++- langchain-core/src/prompts/tests/chat.test.ts | 64 +++++++++++++++++++ .../src/prompts/tests/prompt.test.ts | 13 ++++ 4 files changed, 87 insertions(+), 4 deletions(-) diff --git a/langchain-core/src/prompts/prompt.ts b/langchain-core/src/prompts/prompt.ts index 998715895141..2ef8e9263ae9 100644 --- a/langchain-core/src/prompts/prompt.ts +++ b/langchain-core/src/prompts/prompt.ts @@ -84,7 +84,8 @@ type ExtractTemplateParamsRecursive< export type ParamsFromFString = { [Key in | ExtractTemplateParamsRecursive[number] - | (string & Record)]: string; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | (string & Record)]: any; }; export type ExtractedFStringParams< diff --git a/langchain-core/src/prompts/template.ts b/langchain-core/src/prompts/template.ts index 2188ba1f870f..e6ddb98c087e 100644 --- a/langchain-core/src/prompts/template.ts +++ b/langchain-core/src/prompts/template.ts @@ -106,17 +106,22 @@ export const parseMustache = (template: string) => { return mustacheTemplateToNodes(parsed); }; -export const interpolateFString = (template: string, values: InputValues) => - parseFString(template).reduce((res, node) => { +export const interpolateFString = (template: string, values: InputValues) => { + return parseFString(template).reduce((res, node) => { if (node.type === "variable") { if (node.name in values) { - return res + values[node.name]; + const stringValue = + typeof values[node.name] === "string" + ? values[node.name] + : JSON.stringify(values[node.name]); + return res + stringValue; } throw new Error(`(f-string) Missing value for input ${node.name}`); } return res + node.text; }, ""); +}; export const interpolateMustache = (template: string, values: InputValues) => { configureMustache(); diff --git a/langchain-core/src/prompts/tests/chat.test.ts b/langchain-core/src/prompts/tests/chat.test.ts index b3fc38958ddc..26b44f1a4f44 100644 --- a/langchain-core/src/prompts/tests/chat.test.ts +++ b/langchain-core/src/prompts/tests/chat.test.ts @@ -15,6 +15,7 @@ import { ChatMessage, FunctionMessage, } from "../../messages/index.js"; +import { Document } from "../../documents/document.js"; function createChatPromptTemplate() { const systemPrompt = new PromptTemplate({ @@ -129,6 +130,23 @@ test("Test fromTemplate", async () => { ]); }); +test("Test fromTemplate", async () => { + const chatPrompt = ChatPromptTemplate.fromTemplate("Hello {foo}, I'm {bar}"); + expect(chatPrompt.inputVariables).toEqual(["foo", "bar"]); + expect( + ( + await chatPrompt.invoke({ + foo: ["barbar"], + bar: [new Document({ pageContent: "bar" })], + }) + ).toChatMessages() + ).toEqual([ + new HumanMessage( + `Hello ["barbar"], I'm [{"pageContent":"bar","metadata":{}}]` + ), + ]); +}); + test("Test fromMessages", async () => { const systemPrompt = new PromptTemplate({ template: "Here's some context: {context}", @@ -155,6 +173,34 @@ test("Test fromMessages", async () => { ]); }); +test("Test fromMessages with non-string inputs", async () => { + const systemPrompt = new PromptTemplate({ + template: "Here's some context: {context}", + inputVariables: ["context"], + }); + const userPrompt = new PromptTemplate({ + template: "Hello {foo}, I'm {bar}", + inputVariables: ["foo", "bar"], + }); + // TODO: Fix autocomplete for the fromMessages method + const chatPrompt = ChatPromptTemplate.fromMessages([ + new SystemMessagePromptTemplate(systemPrompt), + new HumanMessagePromptTemplate(userPrompt), + ]); + expect(chatPrompt.inputVariables).toEqual(["context", "foo", "bar"]); + const messages = await chatPrompt.formatPromptValue({ + context: [new Document({ pageContent: "bar" })], + foo: "Foo", + bar: "Bar", + }); + expect(messages.toChatMessages()).toEqual([ + new SystemMessage( + `Here's some context: [{"pageContent":"bar","metadata":{}}]` + ), + new HumanMessage("Hello Foo, I'm Bar"), + ]); +}); + test("Test fromMessages with a variety of ways to declare prompt messages", async () => { const systemPrompt = new PromptTemplate({ template: "Here's some context: {context}", @@ -306,6 +352,24 @@ test("Test MessagesPlaceholder not optional", async () => { ); }); +test("Test MessagesPlaceholder not optional with invalid input should throw", async () => { + const prompt = new MessagesPlaceholder({ + variableName: "foo", + }); + const badInput = [new Document({ pageContent: "barbar", metadata: {} })]; + await expect( + prompt.formatMessages({ + foo: [new Document({ pageContent: "barbar", metadata: {} })], + }) + ).rejects.toThrow( + `Field "foo" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages or coerceable values as input.\n\nReceived value: ${JSON.stringify( + badInput, + null, + 2 + )}\n\nAdditional message: Unable to coerce message from array: only human, AI, or system message coercion is currently supported.` + ); +}); + test("Test MessagesPlaceholder shorthand in a chat prompt template should throw for invalid syntax", async () => { expect(() => ChatPromptTemplate.fromMessages([["placeholder", "foo"]]) diff --git a/langchain-core/src/prompts/tests/prompt.test.ts b/langchain-core/src/prompts/tests/prompt.test.ts index 3d7e30e16436..e5902c447401 100644 --- a/langchain-core/src/prompts/tests/prompt.test.ts +++ b/langchain-core/src/prompts/tests/prompt.test.ts @@ -1,5 +1,6 @@ import { expect, test } from "@jest/globals"; import { PromptTemplate } from "../prompt.js"; +import { Document } from "../../documents/document.js"; test("Test using partial", async () => { const prompt = new PromptTemplate({ @@ -26,6 +27,18 @@ test("Test fromTemplate", async () => { ).toBe("foobaz"); }); +test("Test fromTemplate with a non-string value", async () => { + const prompt = PromptTemplate.fromTemplate("{foo}{bar}"); + expect( + ( + await prompt.invoke({ + foo: ["barbar"], + bar: [new Document({ pageContent: "bar" })], + }) + ).value + ).toBe(`["barbar"][{"pageContent":"bar","metadata":{}}]`); +}); + test("Test fromTemplate with escaped strings", async () => { const prompt = PromptTemplate.fromTemplate("{{foo}}{{bar}}"); expect(await prompt.format({ unused: "eee" })).toBe("{foo}{bar}");