Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed Jan 8, 2024
1 parent 3ad3d4f commit fb2bb02
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 87 deletions.
35 changes: 6 additions & 29 deletions src/app/api/ai/inner/route.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { NextResponse } from "next/server";
import { OpenAI, OpenAIChat } from "langchain/llms/openai";
import { OpenAIChat } from "langchain/llms/openai";
import { loadTridentInnerChain } from "@/utils/langchain/chains/inner";
import { BaseChatMessage } from "langchain/schema";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines";

export async function POST(request: Request) {
console.log("----- ----- -----");
Expand All @@ -14,33 +14,10 @@ export async function POST(request: Request) {
console.log("pastMessagesJsonString");
console.debug(pastMessagesJsonString);

let chatHistory = undefined;
let chatHistoryLines = "";
if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") {
const pastMessages: Array<{
type: string;
id: string[];
kwargs: { content: string };
}> = JSON.parse(pastMessagesJsonString);

chatHistory =
pastMessages &&
pastMessages
.map((message) => {
switch (message.id[2]) {
case "HumanMessage":
return `${message.kwargs.content}`;
case "AIMessage":
return null;
default:
return null;
}
})
.filter((v) => v);
if (chatHistory) {
chatHistoryLines = chatHistory.join("\n").replace("\n\n", "\n");
}
}
const chatHistoryLines = parsePastMessagesToLines(
pastMessagesJsonString,
true
);

console.log("");
console.log("chatHistoryLines:");
Expand Down
45 changes: 37 additions & 8 deletions src/app/api/ai/suggests/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,42 @@ import { NextResponse } from "next/server";
import { OpenAIChat } from "langchain/llms/openai";
import { loadTridentSuggestChain } from "@/utils/langchain/chains/suggest";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { parsePastMessagesToLines } from "@/utils/trident/parsePastMessagesToLines";

export async function POST(request: Request) {
const res = await request.json();
const query = res.query;
console.log("----- ----- -----");
console.log("----- start suggests -----");

const reqJson = await request.json();
const lang = reqJson.lang;
const location = reqJson.location;
const pastMessagesJsonString = reqJson.pastMessages;

console.log("pastMessagesJsonString");
console.debug(pastMessagesJsonString);

const chatHistoryLines = parsePastMessagesToLines(
pastMessagesJsonString,
true
);

let input = "";

if (lang) {
input = `Primary language of user: ${lang}\n`;
}

if (location) {
input += `Current location of user: ${location}\n`;
}

if (chatHistoryLines) {
input += `\nChat history:\n${chatHistoryLines}`;
}

console.log("");
console.log("input:");
console.log(input);

let embeddings: OpenAIEmbeddings;
let llm: OpenAIChat;
Expand All @@ -27,19 +59,16 @@ export async function POST(request: Request) {
}

const chain = await loadTridentSuggestChain({ embeddings, llm });
const result = await chain.call({ input: query });
const result = await chain.call({ input });

console.log("----- ----- -----");
console.log("----- start suggest -----");
console.log("Human:", query);
console.log("AI:", result.text);
console.log("");
console.log("Suggests:\n", result.text);
console.log("");

console.log("----- end suggest -----");
console.log("----- ----- -----");

return NextResponse.json({
query: query,
suggests: result.text,
});
}
29 changes: 3 additions & 26 deletions src/app/api/ai/surface/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface";
// using openai
import { OpenAIChat } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { parsePastMessagesToChatHistory } from "@/utils/trident/parsePastMessagesToChatHistory";

export async function POST(request: Request) {
console.log("----- ----- -----");
Expand All @@ -14,36 +15,12 @@ export async function POST(request: Request) {
const query = reqJson.query;
const pastMessagesJsonString = reqJson.pastMessages;

let chatHistory = undefined;

if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") {
const pastMessages: Array<{
type: string;
id: string[];
kwargs: { content: string };
}> = JSON.parse(pastMessagesJsonString);

const chatHistoryMessages = pastMessages.map((message) => {
if (message.kwargs.content) {
switch (message.id[2]) {
case "HumanMessage":
return new HumanMessage(message.kwargs.content);
case "AIMessage":
return new AIMessage(message.kwargs.content);
default:
return new HumanMessage("");
}
} else {
return new HumanMessage("");
}
});
chatHistory = new ChatMessageHistory(chatHistoryMessages);
}
const chatHistory = parsePastMessagesToChatHistory(pastMessagesJsonString);

const memory = new BufferMemory({
returnMessages: true,
memoryKey: "history",
chatHistory: chatHistory,
chatHistory,
});

let embeddings: OpenAIEmbeddings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ export const SuggestByCurrentLocation: React.FC<{
}
const thisEffect = async () => {
const resJson = await nextPostJsonWithCache("/api/ai/suggests", {
query: address,
lang: window.navigator.language,
location: address,
});
console.log(resJson.suggests);
if (!resJson.suggests) {
Expand Down
3 changes: 2 additions & 1 deletion src/utils/langchain/chains/suggest/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ export const tridentSuggestExampleList: Array<{
output: string;
}> = [
{
input: "台東区, 東京都, 日本",
input: `Primary language of user: ja
Current location of user: 台東区, 東京都, 日本`,
output: `台東区の地図を表示して
東京都の地図を表示して
日本の地図を表示して`,
Expand Down
2 changes: 1 addition & 1 deletion src/utils/nextPostJson.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const nextPostJsonWithCache = async (
const md5 = new Md5();
md5.appendStr(`${url}\n${bodyJsonString}`);
const hash = md5.end();
const key = `trident-cache_2024-01-07_${hash}`;
const key = `trident-cache_2024-01-08_${hash}`;
const unixtime = Math.floor(new Date().getTime() / 1000);

const fetchAndCache = async () => {
Expand Down
17 changes: 0 additions & 17 deletions src/utils/trident/convertChatHistoryToLines.ts

This file was deleted.

24 changes: 20 additions & 4 deletions src/utils/trident/parsePastMessagesToLines.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { ChatMessageHistory } from "langchain/memory";

export const parsePastMessagesToLines = (pastMessagesJsonString: string) => {
export const parsePastMessagesToLines = (
pastMessagesJsonString: string,
onlyHuman?: boolean
) => {
let chatHistory: Array<string | null> = [];
let chatHistoryLines = "";

if (pastMessagesJsonString && pastMessagesJsonString !== "undefined") {
const pastMessages: Array<{
type: string;
Expand All @@ -15,14 +20,25 @@ export const parsePastMessagesToLines = (pastMessagesJsonString: string) => {
.map((message) => {
switch (message.id[2]) {
case "HumanMessage":
return `Human: ${message.kwargs.content}`;
if (onlyHuman) {
return message.kwargs.content;
} else {
return `Human: ${message.kwargs.content}`;
}
case "AIMessage":
return `AI: ${message.kwargs.content}`;
if (onlyHuman) {
return null;
} else {
return `AI: ${message.kwargs.content}`;
}
default:
return null;
}
})
.filter((v) => v);
}
return chatHistory;
if (chatHistory) {
chatHistoryLines = chatHistory.join("\n").replace("\n\n", "\n");
}
return chatHistoryLines;
};

1 comment on commit fb2bb02

@vercel
Copy link

@vercel vercel bot commented on fb2bb02 Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.