From fb2bb022415bcceaf3b456c3059a4f19ef7df01e Mon Sep 17 00:00:00 2001 From: yuiseki Date: Mon, 8 Jan 2024 09:22:42 +0900 Subject: [PATCH] refactor --- src/app/api/ai/inner/route.ts | 35 +++------------ src/app/api/ai/suggests/route.ts | 45 +++++++++++++++---- src/app/api/ai/surface/route.ts | 29 ++---------- .../SuggestByCurrentLocation/index.tsx | 3 +- src/utils/langchain/chains/suggest/prompt.ts | 3 +- src/utils/nextPostJson.ts | 2 +- .../trident/convertChatHistoryToLines.ts | 17 ------- src/utils/trident/parsePastMessagesToLines.ts | 24 ++++++++-- 8 files changed, 71 insertions(+), 87 deletions(-) delete mode 100644 src/utils/trident/convertChatHistoryToLines.ts diff --git a/src/app/api/ai/inner/route.ts b/src/app/api/ai/inner/route.ts index a7e0bb83..828ae71d 100644 --- a/src/app/api/ai/inner/route.ts +++ b/src/app/api/ai/inner/route.ts @@ -1,8 +1,8 @@ import { NextResponse } from "next/server"; -import { OpenAI, OpenAIChat } from "langchain/llms/openai"; +import { OpenAIChat } from "langchain/llms/openai"; import { loadTridentInnerChain } from "@/utils/langchain/chains/inner"; -import { BaseChatMessage } from "langchain/schema"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines"; export async function POST(request: Request) { console.log("----- ----- -----"); @@ -14,33 +14,10 @@ export async function POST(request: Request) { console.log("pastMessagesJsonString"); console.debug(pastMessagesJsonString); - let chatHistory = undefined; - let chatHistoryLines = ""; - if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") { - const pastMessages: Array<{ - type: string; - id: string[]; - kwargs: { content: string }; - }> = JSON.parse(pastMessagesJsonString); - - chatHistory = - pastMessages && - pastMessages - .map((message) => { - switch (message.id[2]) { - case "HumanMessage": - return `${message.kwargs.content}`; - case "AIMessage": - return null; - default: - return null; - } - }) - .filter((v) => v); - if (chatHistory) { - chatHistoryLines = chatHistory.join("\n").replace("\n\n", "\n"); - } - } + const chatHistoryLines = parsePastMessagesToLines( + pastMessagesJsonString, + true + ); console.log(""); console.log("chatHistoryLines:"); diff --git a/src/app/api/ai/suggests/route.ts b/src/app/api/ai/suggests/route.ts index 6b09aeb8..c52632a5 100644 --- a/src/app/api/ai/suggests/route.ts +++ b/src/app/api/ai/suggests/route.ts @@ -2,10 +2,42 @@ import { NextResponse } from "next/server"; import { OpenAIChat } from "langchain/llms/openai"; import { loadTridentSuggestChain } from "@/utils/langchain/chains/suggest"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines"; export async function POST(request: Request) { - const res = await request.json(); - const query = res.query; + console.log("----- ----- -----"); + console.log("----- start suggests -----"); + + const reqJson = await request.json(); + const lang = reqJson.lang; + const location = reqJson.location; + const pastMessagesJsonString = reqJson.pastMessages; + + console.log("pastMessagesJsonString"); + console.debug(pastMessagesJsonString); + + const chatHistoryLines = parsePastMessagesToLines( + pastMessagesJsonString, + true + ); + + let input = ""; + + if (lang) { + input = `Primary language of user: ${lang}\n`; + } + + if (location) { + input += `Current location of user: ${location}\n`; + } + + if (chatHistoryLines) { + input += `\nChat history:\n${chatHistoryLines}`; + } + + console.log(""); + console.log("input:"); + console.log(input); let embeddings: OpenAIEmbeddings; let llm: OpenAIChat; @@ -27,19 +59,16 @@ export async function POST(request: Request) { } const chain = await loadTridentSuggestChain({ embeddings, llm }); - const result = await chain.call({ input: query }); + const result = await chain.call({ input }); - console.log("----- ----- -----"); - console.log("----- start suggest -----"); - console.log("Human:", query); - console.log("AI:", result.text); + console.log(""); + console.log("Suggests:\n", result.text); console.log(""); console.log("----- end suggest -----"); console.log("----- ----- -----"); return NextResponse.json({ - query: query, suggests: result.text, }); } diff --git a/src/app/api/ai/surface/route.ts b/src/app/api/ai/surface/route.ts index d6782278..591d2403 100644 --- a/src/app/api/ai/surface/route.ts +++ b/src/app/api/ai/surface/route.ts @@ -5,6 +5,7 @@ import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface"; // using openai import { OpenAIChat } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { parsePastMessagesToChatHistory } from "@/utils/trident/parsePastMessagesToChatHistory"; export async function POST(request: Request) { console.log("----- ----- -----"); @@ -14,36 +15,12 @@ export async function POST(request: Request) { const query = reqJson.query; const pastMessagesJsonString = reqJson.pastMessages; - let chatHistory = undefined; - - if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") { - const pastMessages: Array<{ - type: string; - id: string[]; - kwargs: { content: string }; - }> = JSON.parse(pastMessagesJsonString); - - const chatHistoryMessages = pastMessages.map((message) => { - if (message.kwargs.content) { - switch (message.id[2]) { - case "HumanMessage": - return new HumanMessage(message.kwargs.content); - case "AIMessage": - return new AIMessage(message.kwargs.content); - default: - return new HumanMessage(""); - } - } else { - return new HumanMessage(""); - } - }); - chatHistory = new ChatMessageHistory(chatHistoryMessages); - } + const chatHistory = parsePastMessagesToChatHistory(pastMessagesJsonString); const memory = new BufferMemory({ returnMessages: true, memoryKey: "history", - chatHistory: chatHistory, + chatHistory, }); let embeddings: OpenAIEmbeddings; diff --git a/src/components/InputSuggest/SuggestByCurrentLocation/index.tsx b/src/components/InputSuggest/SuggestByCurrentLocation/index.tsx index 0ce6b499..0c365314 100644 --- a/src/components/InputSuggest/SuggestByCurrentLocation/index.tsx +++ b/src/components/InputSuggest/SuggestByCurrentLocation/index.tsx @@ -43,7 +43,8 @@ export const SuggestByCurrentLocation: React.FC<{ } const thisEffect = async () => { const resJson = await nextPostJsonWithCache("/api/ai/suggests", { - query: address, + lang: window.navigator.language, + location: address, }); console.log(resJson.suggests); if (!resJson.suggests) { diff --git a/src/utils/langchain/chains/suggest/prompt.ts b/src/utils/langchain/chains/suggest/prompt.ts index 08af1a34..7ce3d6a0 100644 --- a/src/utils/langchain/chains/suggest/prompt.ts +++ b/src/utils/langchain/chains/suggest/prompt.ts @@ -11,7 +11,8 @@ export const tridentSuggestExampleList: Array<{ output: string; }> = [ { - input: "台東区, 東京都, 日本", + input: `Primary language of user: ja +Current location of user: 台東区, 東京都, 日本`, output: `台東区の地図を表示して 東京都の地図を表示して 日本の地図を表示して`, diff --git a/src/utils/nextPostJson.ts b/src/utils/nextPostJson.ts index da782920..fec70a1a 100644 --- a/src/utils/nextPostJson.ts +++ b/src/utils/nextPostJson.ts @@ -19,7 +19,7 @@ export const nextPostJsonWithCache = async ( const md5 = new Md5(); md5.appendStr(`${url}\n${bodyJsonString}`); const hash = md5.end(); - const key = `trident-cache_2024-01-07_${hash}`; + const key = `trident-cache_2024-01-08_${hash}`; const unixtime = Math.floor(new Date().getTime() / 1000); const fetchAndCache = async () => { diff --git a/src/utils/trident/convertChatHistoryToLines.ts b/src/utils/trident/convertChatHistoryToLines.ts deleted file mode 100644 index 2a5b31f2..00000000 --- a/src/utils/trident/convertChatHistoryToLines.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { ChatMessageHistory } from "langchain/memory"; - -export const convertChatHistoryToLines = async ( - chatHistory?: ChatMessageHistory, - excludeAI?: boolean -) => { - let chatHistoryLines = ""; - if (!chatHistory) { - return chatHistoryLines; - } - const messages = await chatHistory.getMessages(); - messages.forEach((message) => { - console.log(message); - chatHistoryLines += message.lc_kwargs.content + "\n"; - }); - return chatHistoryLines; -}; diff --git a/src/utils/trident/parsePastMessagesToLines.ts b/src/utils/trident/parsePastMessagesToLines.ts index 5774cbe4..e23cbb90 100644 --- a/src/utils/trident/parsePastMessagesToLines.ts +++ b/src/utils/trident/parsePastMessagesToLines.ts @@ -1,7 +1,12 @@ import { ChatMessageHistory } from "langchain/memory"; -export const parsePastMessagesToLines = (pastMessagesJsonString: string) => { +export const parsePastMessagesToLines = ( + pastMessagesJsonString: string, + onlyHuman?: boolean +) => { let chatHistory: Array = []; + let chatHistoryLines = ""; + if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") { const pastMessages: Array<{ type: string; @@ -15,14 +20,25 @@ export const parsePastMessagesToLines = (pastMessagesJsonString: string) => { .map((message) => { switch (message.id[2]) { case "HumanMessage": - return `Human: ${message.kwargs.content}`; + if (onlyHuman) { + return message.kwargs.content; + } else { + return `Human: ${message.kwargs.content}`; + } case "AIMessage": - return `AI: ${message.kwargs.content}`; + if (onlyHuman) { + return null; + } else { + return `AI: ${message.kwargs.content}`; + } default: return null; } }) .filter((v) => v); } - return chatHistory; + if (chatHistory) { + chatHistoryLines = chatHistory.join("\n").replace("\n\n", "\n"); + } + return chatHistoryLines; };