Skip to content

Commit

Permalink
refactor: simplify chat history handling and update vector store integration
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed Nov 9, 2024
1 parent 76b473f commit 5a35202
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 51 deletions.
6 changes: 1 addition & 5 deletions src/app/api/ai/inner/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ export async function POST(request: Request) {

const reqJson = await request.json();
const pastMessagesJsonString = reqJson.pastMessages;

const chatHistoryLines = parsePastMessagesToLines(
pastMessagesJsonString,
true
);
const chatHistoryLines = pastMessagesJsonString;

console.log("");
console.log("chatHistoryLines:");
Expand Down
33 changes: 14 additions & 19 deletions src/app/api/ai/surface/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface";
// using openai
import { ChatOpenAI } from "@langchain/openai";
import { OpenAIEmbeddings } from "@langchain/openai";
import { parsePastMessagesToChatHistory } from "@/utils/trident/parsePastMessagesToChatHistory";
import { BufferMemory } from "langchain/memory";
import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines";
import { MemoryVectorStore } from "langchain/vectorstores/memory";

export async function POST(request: Request) {
console.log("----- ----- -----");
Expand All @@ -13,14 +13,12 @@ export async function POST(request: Request) {
const reqJson = await request.json();
const query = reqJson.query;
const pastMessagesJsonString = reqJson.pastMessages;
let chatHistoryLines = pastMessagesJsonString;
chatHistoryLines = chatHistoryLines + "\nHuman: " + query;

const chatHistory = parsePastMessagesToChatHistory(pastMessagesJsonString);

const memory = new BufferMemory({
returnMessages: true,
memoryKey: "history",
chatHistory,
});
console.log("");
console.log("chatHistoryLines:");
console.log(chatHistoryLines);

let llm: ChatOpenAI;
let embeddings: OpenAIEmbeddings;
Expand All @@ -41,27 +39,24 @@ export async function POST(request: Request) {
embeddings = new OpenAIEmbeddings();
}

const vectorStore = new MemoryVectorStore(embeddings);

const surfaceChain = await loadTridentSurfaceChain({
llm,
embeddings,
memory,
vectorStore,
});
const surfaceResult = await surfaceChain.call({ input: query });
const surfaceResult = await surfaceChain.invoke({ input: chatHistoryLines });

console.log("Human:", query);
console.log("AI:", surfaceResult.response);
console.log("AI:", surfaceResult.text);
console.log("");

const history = await memory.chatHistory.getMessages();
// debug用
//console.log(history);

console.log("----- end surface -----");
console.log("----- ----- -----");

return NextResponse.json({
query: query,
surface: surfaceResult.response,
history: history,
surface: surfaceResult.text,
history: chatHistoryLines,
});
}
24 changes: 6 additions & 18 deletions src/utils/langchain/chains/surface/index.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,16 @@
import { Embeddings } from "@langchain/core/embeddings";
import { loadTridentSurfacePrompt } from "./prompt";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
import { BaseMemory } from "@langchain/core/memory";
import { RunnableSequence } from "@langchain/core/runnables";
import { BufferMemory } from "langchain/memory";
import { ConversationChain } from "langchain/chains";
import { VectorStore } from "@langchain/core/vectorstores";

export const loadTridentSurfaceChain = async ({
embeddings,
llm,
memory,
vectorStore,
}: {
embeddings: Embeddings;
llm: BaseLanguageModel;
memory?: BaseMemory;
}): Promise<ConversationChain> => {
if (memory === undefined) {
memory = new BufferMemory();
}
const prompt = await loadTridentSurfacePrompt(embeddings);
const chain = new ConversationChain({
llm: llm,
prompt: prompt,
memory: memory,
});
vectorStore: VectorStore;
}): Promise<RunnableSequence> => {
const prompt = await loadTridentSurfacePrompt(vectorStore);
const chain = RunnableSequence.from([prompt, llm]);
return chain;
};
24 changes: 15 additions & 9 deletions src/utils/langchain/chains/surface/prompt.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Embeddings } from "@langchain/core/embeddings";
import { SemanticSimilarityExampleSelector } from "@langchain/core/example_selectors";
import { FewShotPromptTemplate, PromptTemplate } from "@langchain/core/prompts";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { VectorStore } from "@langchain/core/vectorstores";

export const tridentSurfaceExampleList: Array<{
input: string;
Expand All @@ -17,6 +16,11 @@ export const tridentSurfaceExampleList: Array<{
output:
"了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……",
},
{
input: "台東区を表示して",
output:
"了解しました。OpenStreetMapのデータに基づいて台東区を表示する地図を作成しています。しばらくお待ちください……",
},
{
input: "显示纽约地图",
output: "知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……",
Expand All @@ -42,10 +46,9 @@ You will always reply according to the following rules:
### Examples: ###`;

export const loadTridentSurfacePrompt = async (embeddings: Embeddings) => {
const memoryVectorStore = new MemoryVectorStore(embeddings);
export const loadTridentSurfacePrompt = async (vectorStore: VectorStore) => {
const exampleSelector = new SemanticSimilarityExampleSelector({
vectorStore: memoryVectorStore,
vectorStore: vectorStore,
k: 3,
inputKeys: ["input"],
});
Expand All @@ -69,10 +72,13 @@ AI:
suffix: `
### Current conversation: ###
{history}
Human: {input}
AI: `,
inputVariables: ["history", "input"],
Human:
{input}
AI:
`,
inputVariables: ["input"],
});

return dynamicPrompt;
};

0 comments on commit 5a35202

Please sign in to comment.