From 5a35202078d7b9364199a0c3ecae7fd4f6dfc6a9 Mon Sep 17 00:00:00 2001 From: yuiseki Date: Sat, 9 Nov 2024 21:53:11 +0900 Subject: [PATCH] refactor: simplify chat history handling and update vector store integration --- src/app/api/ai/inner/route.ts | 6 +--- src/app/api/ai/surface/route.ts | 33 +++++++++----------- src/utils/langchain/chains/surface/index.ts | 24 ++++---------- src/utils/langchain/chains/surface/prompt.ts | 24 ++++++++------ 4 files changed, 36 insertions(+), 51 deletions(-) diff --git a/src/app/api/ai/inner/route.ts b/src/app/api/ai/inner/route.ts index cfbf2020..78e503c8 100644 --- a/src/app/api/ai/inner/route.ts +++ b/src/app/api/ai/inner/route.ts @@ -18,11 +18,7 @@ export async function POST(request: Request) { const reqJson = await request.json(); const pastMessagesJsonString = reqJson.pastMessages; - - const chatHistoryLines = parsePastMessagesToLines( - pastMessagesJsonString, - true - ); + const chatHistoryLines = pastMessagesJsonString; console.log(""); console.log("chatHistoryLines:"); diff --git a/src/app/api/ai/surface/route.ts b/src/app/api/ai/surface/route.ts index 66c75c00..e43e6fcd 100644 --- a/src/app/api/ai/surface/route.ts +++ b/src/app/api/ai/surface/route.ts @@ -3,8 +3,8 @@ import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface"; // using openai import { ChatOpenAI } from "@langchain/openai"; import { OpenAIEmbeddings } from "@langchain/openai"; -import { parsePastMessagesToChatHistory } from "@/utils/trident/parsePastMessagesToChatHistory"; -import { BufferMemory } from "langchain/memory"; +import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines"; +import { MemoryVectorStore } from "langchain/vectorstores/memory"; export async function POST(request: Request) { console.log("----- ----- -----"); @@ -13,14 +13,12 @@ export async function POST(request: Request) { const reqJson = await request.json(); const query = reqJson.query; const pastMessagesJsonString = reqJson.pastMessages; + let chatHistoryLines = pastMessagesJsonString; + chatHistoryLines = chatHistoryLines + "\nHuman: " + query; - const chatHistory = parsePastMessagesToChatHistory(pastMessagesJsonString); - - const memory = new BufferMemory({ - returnMessages: true, - memoryKey: "history", - chatHistory, - }); + console.log(""); + console.log("chatHistoryLines:"); + console.log(chatHistoryLines); let llm: ChatOpenAI; let embeddings: OpenAIEmbeddings; @@ -41,27 +39,24 @@ export async function POST(request: Request) { embeddings = new OpenAIEmbeddings(); } + const vectorStore = new MemoryVectorStore(embeddings); + const surfaceChain = await loadTridentSurfaceChain({ llm, - embeddings, - memory, + vectorStore, }); - const surfaceResult = await surfaceChain.call({ input: query }); + const surfaceResult = await surfaceChain.invoke({ input: chatHistoryLines }); console.log("Human:", query); - console.log("AI:", surfaceResult.response); + console.log("AI:", surfaceResult.text); console.log(""); - const history = await memory.chatHistory.getMessages(); - // debug用 - //console.log(history); - console.log("----- end surface -----"); console.log("----- ----- -----"); return NextResponse.json({ query: query, - surface: surfaceResult.response, - history: history, + surface: surfaceResult.text, + history: chatHistoryLines, }); } diff --git a/src/utils/langchain/chains/surface/index.ts b/src/utils/langchain/chains/surface/index.ts index 13ee59e3..1428af95 100644 --- a/src/utils/langchain/chains/surface/index.ts +++ b/src/utils/langchain/chains/surface/index.ts @@ -1,28 +1,16 @@ -import { Embeddings } from "@langchain/core/embeddings"; import { loadTridentSurfacePrompt } from "./prompt"; import { BaseLanguageModel } from "@langchain/core/language_models/base"; -import { BaseMemory } from "@langchain/core/memory"; import { RunnableSequence } from "@langchain/core/runnables"; -import { BufferMemory } from "langchain/memory"; -import { ConversationChain } from "langchain/chains"; +import { VectorStore } from "@langchain/core/vectorstores"; export const loadTridentSurfaceChain = async ({ - embeddings, llm, - memory, + vectorStore, }: { - embeddings: Embeddings; llm: BaseLanguageModel; - memory?: BaseMemory; -}): Promise => { - if (memory === undefined) { - memory = new BufferMemory(); - } - const prompt = await loadTridentSurfacePrompt(embeddings); - const chain = new ConversationChain({ - llm: llm, - prompt: prompt, - memory: memory, - }); + vectorStore: VectorStore; +}): Promise => { + const prompt = await loadTridentSurfacePrompt(vectorStore); + const chain = RunnableSequence.from([prompt, llm]); return chain; }; diff --git a/src/utils/langchain/chains/surface/prompt.ts b/src/utils/langchain/chains/surface/prompt.ts index 0dfc17ca..fbbf3dc8 100644 --- a/src/utils/langchain/chains/surface/prompt.ts +++ b/src/utils/langchain/chains/surface/prompt.ts @@ -1,7 +1,6 @@ -import { Embeddings } from "@langchain/core/embeddings"; import { SemanticSimilarityExampleSelector } from "@langchain/core/example_selectors"; import { FewShotPromptTemplate, PromptTemplate } from "@langchain/core/prompts"; -import { MemoryVectorStore } from "langchain/vectorstores/memory"; +import { VectorStore } from "@langchain/core/vectorstores"; export const tridentSurfaceExampleList: Array<{ input: string; @@ -17,6 +16,11 @@ export const tridentSurfaceExampleList: Array<{ output: "了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……", }, + { + input: "台東区を表示して", + output: + "了解しました。OpenStreetMapのデータに基づいて台東区を表示する地図を作成しています。しばらくお待ちください……", + }, { input: "显示纽约地图", output: "知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……", @@ -42,10 +46,9 @@ You will always reply according to the following rules: ### Examples: ###`; -export const loadTridentSurfacePrompt = async (embeddings: Embeddings) => { - const memoryVectorStore = new MemoryVectorStore(embeddings); +export const loadTridentSurfacePrompt = async (vectorStore: VectorStore) => { const exampleSelector = new SemanticSimilarityExampleSelector({ - vectorStore: memoryVectorStore, + vectorStore: vectorStore, k: 3, inputKeys: ["input"], }); @@ -69,10 +72,13 @@ AI: suffix: ` ### Current conversation: ### -{history} -Human: {input} -AI: `, - inputVariables: ["history", "input"], +Human: +{input} + +AI: +`, + inputVariables: ["input"], }); + return dynamicPrompt; };